ably.do/ollama_service.py

181 lines
6.6 KiB
Python

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import ollama
import weaviate
from weaviate.connect import ConnectionParams
from weaviate.collections.classes.filters import Filter
import re
import uvicorn
app = FastAPI()
OLLAMA_BASE_URL = "http://ollama:11434"
WEAVIATE_URL = "http://weaviate:8080"
# Inicjalizacja klientów
ollama_client = ollama.Client(host=OLLAMA_BASE_URL)
weaviate_client = weaviate.WeaviateClient(
connection_params=ConnectionParams.from_params(
http_host="weaviate",
http_port=8080,
http_secure=False,
grpc_host="weaviate",
grpc_port=50051,
grpc_secure=False,
)
)
weaviate_client.connect()
# Pobierz kolekcję
collection = weaviate_client.collections.get("Document")
prompt = """
Jesteś precyzyjnym narzędziem do generowania słów kluczowych z zakresu BHP i prawa pracy. Twoje zadanie to podanie WYŁĄCZNIE najistotniejszych słów do wyszukiwania w bazie dokumentów prawnych.
Ścisłe zasady:
1. Jeśli zapytanie dotyczy konkretnego artykułu:
- Podaj TYLKO numer artykułu i nazwę kodeksu (np. "Art. 154, Kodeks pracy").
- NIE dodawaj żadnych innych słów.
2. Jeśli zapytanie nie dotyczy konkretnego artykułu:
- Podaj maksymalnie 3 najbardziej specyficzne terminy związane z zapytaniem.
- Unikaj ogólnych słów jak "praca", "pracownik", "pracodawca", chyba że są częścią specjalistycznego terminu.
3. Używaj wyłącznie terminów, które z pewnością występują w dokumentach prawnych lub specjalistycznych opracowaniach.
4. NIE dodawaj własnych interpretacji ani rozszerzeń zapytania.
Odpowiedz TYLKO listą słów kluczowych oddzielonych przecinkami, bez żadnych dodatkowych wyjaśnień czy komentarzy.
Zapytanie: '{query}'
"""
def analyze_query(query):
analysis = ollama_client.chat(
model="gemma2:2b",
messages=[{"role": "user", "content": prompt.format(query=query)}]
)
keywords = [word.strip() for word in analysis['message']['content'].split(',') if word.strip()]
print("Słowa kluczowe:", keywords)
return keywords
def extract_relevant_fragment(content, query, context_size=200):
article_match = re.match(r'Art\.\s*(\d+)', query)
if article_match:
article_number = article_match.group(1)
article_pattern = rf"Art\.\s*{article_number}\..*?(?=Art\.\s*\d+\.|\Z)"
match = re.search(article_pattern, content, re.DOTALL)
if match:
return match.group(0).strip()
index = content.lower().find(query.lower())
if index != -1:
start = max(0, index - context_size)
end = min(len(content), index + len(query) + context_size)
return f"...{content[start:end]}..."
return content[:400] + "..."
def expand_query(keywords):
expansions = {}
expanded_terms = keywords.copy()
for keyword in keywords:
expanded_terms.extend(expansions.get(keyword.lower(), []))
return " ".join(set(expanded_terms))
def extract_relevant_fragment(content, query, context_size=200):
article_pattern = r"Art\.\s*154\..*?(?=Art\.\s*\d+\.|\Z)"
match = re.search(article_pattern, content, re.DOTALL)
if match:
return match.group(0).strip()
index = content.lower().find(query.lower())
if index != -1:
start = max(0, index - context_size)
end = min(len(content), index + len(query) + context_size)
return f"...{content[start:end]}..."
return content[:400] + "..."
def hybrid_search(keywords, limit=5, alpha=0.5):
if isinstance(keywords, str):
keywords = [keywords]
all_results = []
for keyword in keywords:
print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{keyword}'")
response = collection.query.hybrid(
query=keyword,
alpha=alpha,
limit=limit * 2
)
for obj in response.objects:
relevant_fragment = extract_relevant_fragment(obj.properties['content'], keyword)
if keyword.lower() in relevant_fragment.lower():
result = {
"uuid": obj.uuid,
"relevant_fragment": relevant_fragment,
"file_name": obj.properties['fileName'],
"keyword": keyword
}
if result not in all_results:
all_results.append(result)
print(f"UUID: {obj.uuid}")
print(f"Relewantny fragment:\n{relevant_fragment}")
print(f"Nazwa pliku: {obj.properties['fileName']}")
print("---")
if len(all_results) >= limit:
break
if len(all_results) >= limit:
break
return all_results[:limit]
class ChatRequest(BaseModel):
model: str
messages: list[dict]
query: str
class ChatResponse(BaseModel):
content: str
weaviate_results: list
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
try:
keywords = analyze_query(request.query)
weaviate_results = hybrid_search(keywords)
if not weaviate_results:
response = ollama_client.chat(
model=request.model,
messages=[{"role": "user", "content": f"Nie znalazłem informacji na temat: {request.query}. Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi."}]
)
else:
context = "Znalezione informacje:\n"
for item in weaviate_results:
context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n"
response = ollama_client.chat(
model=request.model,
messages=[
{"role": "system", "content": context},
{"role": "user", "content": f"Na podstawie powyższych informacji, odpowiedz na pytanie: {request.query}. Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł."}
]
)
return ChatResponse(
content=response['message']['content'],
weaviate_results=weaviate_results
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
async def list_models():
try:
models = ollama_client.list()
return {"models": [model['name'] for model in models['models']]}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)