diff --git a/ollama_service.py b/ollama_service.py index 60ed2e8..d045cfc 100644 --- a/ollama_service.py +++ b/ollama_service.py @@ -1,11 +1,16 @@ -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse from pydantic import BaseModel import ollama import weaviate from weaviate.connect import ConnectionParams from weaviate.collections.classes.filters import Filter import re +import json import uvicorn +import httpx +from typing import List, Optional +import asyncio app = FastAPI() @@ -25,9 +30,30 @@ weaviate_client = weaviate.WeaviateClient( ) ) weaviate_client.connect() -# Pobierz kolekcję collection = weaviate_client.collections.get("Document") +class Message(BaseModel): + role: str + content: str + +class ChatRequest(BaseModel): + model: str + messages: List[Message] + stream: Optional[bool] = False + options: Optional[dict] = None + +class ChatResponse(BaseModel): + model: str + created_at: str + message: Message + done: bool + total_duration: int + load_duration: int + prompt_eval_count: int + prompt_eval_duration: int + eval_count: int + eval_duration: int + prompt = """ Jesteś precyzyjnym narzędziem do generowania słów kluczowych z zakresu BHP i prawa pracy. Twoje zadanie to podanie WYŁĄCZNIE najistotniejszych słów do wyszukiwania w bazie dokumentów prawnych. 
@@ -55,120 +81,83 @@ def analyze_query(query): print("Słowa kluczowe:", keywords) return keywords -def extract_relevant_fragment(content, query, context_size=200): - article_match = re.match(r'Art\.\s*(\d+)', query) - if article_match: - article_number = article_match.group(1) - article_pattern = rf"Art\.\s*{article_number}\..*?(?=Art\.\s*\d+\.|\Z)" - match = re.search(article_pattern, content, re.DOTALL) - if match: - return match.group(0).strip() - - index = content.lower().find(query.lower()) - if index != -1: - start = max(0, index - context_size) - end = min(len(content), index + len(query) + context_size) - return f"...{content[start:end]}..." - return content[:400] + "..." - -def expand_query(keywords): - expansions = {} - expanded_terms = keywords.copy() - for keyword in keywords: - expanded_terms.extend(expansions.get(keyword.lower(), [])) - return " ".join(set(expanded_terms)) - -def extract_relevant_fragment(content, query, context_size=200): - article_pattern = r"Art\.\s*154\..*?(?=Art\.\s*\d+\.|\Z)" - match = re.search(article_pattern, content, re.DOTALL) +def extract_full_article(content, article_number): + pattern = rf"Art\.\s*{article_number}\..*?(?=Art\.\s*\d+\.|\Z)" + match = re.search(pattern, content, re.DOTALL) if match: return match.group(0).strip() - + return None + +def extract_relevant_fragment(content, query, context_size=100): + article_match = re.match(r"Art\.\s*(\d+)", query) + if article_match: + article_number = article_match.group(1) + full_article = extract_full_article(content, article_number) + if full_article: + return full_article + index = content.lower().find(query.lower()) if index != -1: start = max(0, index - context_size) end = min(len(content), index + len(query) + context_size) return f"...{content[start:end]}..." - return content[:400] + "..." + return content[:200] + "..." 
def hybrid_search(keywords, limit=5, alpha=0.5): if isinstance(keywords, str): keywords = [keywords] + + query = " ".join(keywords) - all_results = [] - for keyword in keywords: - print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{keyword}'") - response = collection.query.hybrid( - query=keyword, - alpha=alpha, - limit=limit * 2 - ) + print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{query}'") + response = collection.query.hybrid( + query=query, + alpha=alpha, + limit=limit * 2 + ) + + results = [] + + for obj in response.objects: + #print(f"UUID: {obj.uuid}") + relevant_fragment = extract_relevant_fragment(obj.properties['content'], query) + #print(f"Relewantny fragment:\n{relevant_fragment}") + #print(f"Nazwa pliku: {obj.properties['fileName']}") + #print("---") + # Zmieniamy warunek na 'any' zamiast 'all' + #if any(term.lower() in relevant_fragment.lower() for term in keywords): + results.append({ + "uuid": obj.uuid, + "relevant_fragment": relevant_fragment, + "file_name": obj.properties['fileName'], + "keyword": query + }) + print(f"Dodano do wyników: {obj.uuid}") - for obj in response.objects: - relevant_fragment = extract_relevant_fragment(obj.properties['content'], keyword) - if keyword.lower() in relevant_fragment.lower(): - result = { - "uuid": obj.uuid, - "relevant_fragment": relevant_fragment, - "file_name": obj.properties['fileName'], - "keyword": keyword - } - if result not in all_results: - all_results.append(result) - print(f"UUID: {obj.uuid}") - print(f"Relewantny fragment:\n{relevant_fragment}") - print(f"Nazwa pliku: {obj.properties['fileName']}") - print("---") - - if len(all_results) >= limit: - break - - if len(all_results) >= limit: + if len(results) >= limit: break - - return all_results[:limit] + return results[:limit] -class ChatRequest(BaseModel): - model: str - messages: list[dict] - query: str +@app.get("/api/tags") +async def tags_proxy(): + async with httpx.AsyncClient() as client: + response = await 
client.get(f"{OLLAMA_BASE_URL}/api/tags") + return response.json() -class ChatResponse(BaseModel): - content: str - weaviate_results: list +@app.get("/api/version") +async def version_proxy(): + async with httpx.AsyncClient() as client: + response = await client.get(f"{OLLAMA_BASE_URL}/api/version") + return response.json() -@app.post("/chat", response_model=ChatResponse) -async def chat_endpoint(request: ChatRequest): - try: - keywords = analyze_query(request.query) - weaviate_results = hybrid_search(keywords) - - if not weaviate_results: - response = ollama_client.chat( - model=request.model, - messages=[{"role": "user", "content": f"Nie znalazłem informacji na temat: {request.query}. Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi."}] - ) - else: - context = "Znalezione informacje:\n" - for item in weaviate_results: - context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n" - - response = ollama_client.chat( - model=request.model, - messages=[ - {"role": "system", "content": context}, - {"role": "user", "content": f"Na podstawie powyższych informacji, odpowiedz na pytanie: {request.query}. 
Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł."} - ] - ) +@app.post("/api/generate") +async def generate_proxy(request: Request): + data = await request.json() + async with httpx.AsyncClient() as client: + response = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json={**data, "stream": False}) + return response.json() - return ChatResponse( - content=response['message']['content'], - weaviate_results=weaviate_results - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/models") +@app.get("/api/models") async def list_models(): try: models = ollama_client.list() @@ -176,5 +165,71 @@ async def list_models(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) +async def stream_chat(model, messages, options): + try: + # Użycie httpx do asynchronicznego pobrania danych od Ollamy + async with httpx.AsyncClient() as client: + async with client.stream( + "POST", + f"{OLLAMA_BASE_URL}/api/chat", + json={"model": model, "messages": messages, "stream": True, "options": options}, + ) as response: + async for line in response.aiter_lines(): + yield line + "\n" + except Exception as e: + yield json.dumps({"error": str(e)}) + "\n" + +@app.post("/api/chat") +async def chat_endpoint(request: ChatRequest): + try: + query = request.messages[-1].content if request.messages else "" + keywords = analyze_query(query) + weaviate_results = hybrid_search(keywords) + + if not weaviate_results: + context = f""" + Nie znalazłem informacji na temat: {query}. + Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi. + """ + else: + context = "Znalezione informacje:\n" + for item in weaviate_results: + context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n" + + messages_with_context =[ + {"role": "system", "content": context}, + {"role": "user", "content": f""" + Na podstawie powyższych informacji, odpowiedz na pytanie: {query}. 
+ Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł. + """} + ] + + if request.stream: + return StreamingResponse(stream_chat(request.model, messages_with_context, request.options), media_type="application/json") + + ollama_response = ollama_client.chat( + model=request.model, + messages=messages_with_context, + stream=False, + options=request.options + ) + return ChatResponse( + model=request.model, + created_at=ollama_response.get('created_at', ''), + message=Message( + role=ollama_response['message']['role'], + content=ollama_response['message']['content'] + ), + done=ollama_response.get('done', True), + total_duration=ollama_response.get('total_duration', 0), + load_duration=ollama_response.get('load_duration', 0), + prompt_eval_count=ollama_response.get('prompt_eval_count', 0), + prompt_eval_duration=ollama_response.get('prompt_eval_duration', 0), + eval_count=ollama_response.get('eval_count', 0), + eval_duration=ollama_response.get('eval_duration', 0) + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3557097..8e207d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ fastapi uvicorn ollama -weaviate-client \ No newline at end of file +weaviate-client +unidecode \ No newline at end of file