from fastapi import FastAPI, HTTPException from pydantic import BaseModel import ollama import weaviate from weaviate.connect import ConnectionParams from weaviate.collections.classes.filters import Filter import re import uvicorn app = FastAPI() OLLAMA_BASE_URL = "http://ollama:11434" WEAVIATE_URL = "http://weaviate:8080" # Inicjalizacja klientów ollama_client = ollama.Client(host=OLLAMA_BASE_URL) weaviate_client = weaviate.WeaviateClient( connection_params=ConnectionParams.from_params( http_host="weaviate", http_port=8080, http_secure=False, grpc_host="weaviate", grpc_port=50051, grpc_secure=False, ) ) weaviate_client.connect() # Pobierz kolekcję collection = weaviate_client.collections.get("Document") prompt = """ Jesteś precyzyjnym narzędziem do generowania słów kluczowych z zakresu BHP i prawa pracy. Twoje zadanie to podanie WYŁĄCZNIE najistotniejszych słów do wyszukiwania w bazie dokumentów prawnych. Ścisłe zasady: 1. Jeśli zapytanie dotyczy konkretnego artykułu: - Podaj TYLKO numer artykułu i nazwę kodeksu (np. "Art. 154, Kodeks pracy"). - NIE dodawaj żadnych innych słów. 2. Jeśli zapytanie nie dotyczy konkretnego artykułu: - Podaj maksymalnie 3 najbardziej specyficzne terminy związane z zapytaniem. - Unikaj ogólnych słów jak "praca", "pracownik", "pracodawca", chyba że są częścią specjalistycznego terminu. 3. Używaj wyłącznie terminów, które z pewnością występują w dokumentach prawnych lub specjalistycznych opracowaniach. 4. NIE dodawaj własnych interpretacji ani rozszerzeń zapytania. Odpowiedz TYLKO listą słów kluczowych oddzielonych przecinkami, bez żadnych dodatkowych wyjaśnień czy komentarzy. Zapytanie: '{query}' """ def analyze_query(query): analysis = ollama_client.chat( model="gemma2:2b", messages=[{"role": "user", "content": prompt.format(query=query)}] ) keywords = [word.strip() for word in analysis['message']['content'].split(',') if word.strip()] print("Słowa kluczowe:", keywords) return keywords def extract_relevant_fragment(content, query, context_size=200): article_match = re.match(r'Art\.\s*(\d+)', query) if article_match: article_number = article_match.group(1) article_pattern = rf"Art\.\s*{article_number}\..*?(?=Art\.\s*\d+\.|\Z)" match = re.search(article_pattern, content, re.DOTALL) if match: return match.group(0).strip() index = content.lower().find(query.lower()) if index != -1: start = max(0, index - context_size) end = min(len(content), index + len(query) + context_size) return f"...{content[start:end]}..." return content[:400] + "..." def expand_query(keywords): expansions = {} expanded_terms = keywords.copy() for keyword in keywords: expanded_terms.extend(expansions.get(keyword.lower(), [])) return " ".join(set(expanded_terms)) def extract_relevant_fragment(content, query, context_size=200): article_pattern = r"Art\.\s*154\..*?(?=Art\.\s*\d+\.|\Z)" match = re.search(article_pattern, content, re.DOTALL) if match: return match.group(0).strip() index = content.lower().find(query.lower()) if index != -1: start = max(0, index - context_size) end = min(len(content), index + len(query) + context_size) return f"...{content[start:end]}..." return content[:400] + "..." def hybrid_search(keywords, limit=5, alpha=0.5): if isinstance(keywords, str): keywords = [keywords] all_results = [] for keyword in keywords: print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{keyword}'") response = collection.query.hybrid( query=keyword, alpha=alpha, limit=limit * 2 ) for obj in response.objects: relevant_fragment = extract_relevant_fragment(obj.properties['content'], keyword) if keyword.lower() in relevant_fragment.lower(): result = { "uuid": obj.uuid, "relevant_fragment": relevant_fragment, "file_name": obj.properties['fileName'], "keyword": keyword } if result not in all_results: all_results.append(result) print(f"UUID: {obj.uuid}") print(f"Relewantny fragment:\n{relevant_fragment}") print(f"Nazwa pliku: {obj.properties['fileName']}") print("---") if len(all_results) >= limit: break if len(all_results) >= limit: break return all_results[:limit] class ChatRequest(BaseModel): model: str messages: list[dict] query: str class ChatResponse(BaseModel): content: str weaviate_results: list @app.post("/chat", response_model=ChatResponse) async def chat_endpoint(request: ChatRequest): try: keywords = analyze_query(request.query) weaviate_results = hybrid_search(keywords) if not weaviate_results: response = ollama_client.chat( model=request.model, messages=[{"role": "user", "content": f"Nie znalazłem informacji na temat: {request.query}. Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi."}] ) else: context = "Znalezione informacje:\n" for item in weaviate_results: context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n" response = ollama_client.chat( model=request.model, messages=[ {"role": "system", "content": context}, {"role": "user", "content": f"Na podstawie powyższych informacji, odpowiedz na pytanie: {request.query}. Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł."} ] ) return ChatResponse( content=response['message']['content'], weaviate_results=weaviate_results ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.get("/models") async def list_models(): try: models = ollama_client.list() return {"models": [model['name'] for model in models['models']]} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000)