code ready to deploy

This commit is contained in:
l.gabrysiak 2025-02-27 20:31:55 +01:00
parent 10f40c54ee
commit ad08ce74ce
2 changed files with 156 additions and 100 deletions

View File

@@ -1,11 +1,16 @@
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import ollama
import weaviate
from weaviate.connect import ConnectionParams
from weaviate.collections.classes.filters import Filter
import re
import json
import uvicorn
import httpx
from typing import List, Optional
import asyncio
app = FastAPI()
@@ -25,9 +30,30 @@ weaviate_client = weaviate.WeaviateClient(
)
)
weaviate_client.connect()
# Pobierz kolekcję
collection = weaviate_client.collections.get("Document")
class Message(BaseModel):
    """A single chat message in the Ollama-compatible wire format."""
    # Chat role ("system" / "user" / "assistant") — forwarded to Ollama unchanged.
    role: str
    # Message text.
    content: str
class ChatRequest(BaseModel):
    """Request body of the Ollama-compatible /api/chat endpoint."""
    # Name of the Ollama model to run the chat with.
    model: str
    # Full conversation history; the last entry is treated as the user's query.
    messages: List[Message]
    # When True the reply is streamed line-by-line via StreamingResponse.
    stream: Optional[bool] = False
    # Forwarded verbatim to Ollama — presumably generation options such as
    # temperature; TODO confirm expected keys against Ollama's API.
    options: Optional[dict] = None
class ChatResponse(BaseModel):
    """Non-streaming reply of /api/chat, mirroring Ollama's chat response."""
    model: str
    # Timestamp string as reported by Ollama ('' when absent).
    created_at: str
    # The assistant's message.
    message: Message
    done: bool
    # Timing / token-usage counters copied from Ollama's reply (0 when absent).
    total_duration: int
    load_duration: int
    prompt_eval_count: int
    prompt_eval_duration: int
    eval_count: int
    eval_duration: int
prompt = """
Jesteś precyzyjnym narzędziem do generowania słów kluczowych z zakresu BHP i prawa pracy. Twoje zadanie to podanie WYŁĄCZNIE najistotniejszych słów do wyszukiwania w bazie dokumentów prawnych.
@@ -55,120 +81,83 @@ def analyze_query(query):
print("Słowa kluczowe:", keywords)
return keywords
def extract_relevant_fragment(content, query, context_size=200):
    """Return the piece of *content* most relevant to *query*.

    A query beginning with an article reference ("Art. N") yields the whole
    article text (up to the next "Art. <n>." heading or end of text).
    Otherwise a window of ``context_size`` characters on each side of the
    first case-insensitive occurrence of the query is returned; if the query
    does not occur at all, the first 400 characters are used as a fallback.
    """
    referenced = re.match(r'Art\.\s*(\d+)', query)
    if referenced:
        number = referenced.group(1)
        # Lazy-match up to (but not including) the next article heading.
        pattern = rf"Art\.\s*{number}\..*?(?=Art\.\s*\d+\.|\Z)"
        hit = re.search(pattern, content, re.DOTALL)
        if hit:
            return hit.group(0).strip()

    position = content.lower().find(query.lower())
    if position == -1:
        return content[:400] + "..."

    begin = max(0, position - context_size)
    finish = min(len(content), position + len(query) + context_size)
    return f"...{content[begin:finish]}..."
def expand_query(keywords):
    """Join *keywords* (plus any configured synonyms) into one search string.

    Fix: the original deduplicated via ``" ".join(set(...))``, which made the
    term order nondeterministic between runs; ``dict.fromkeys`` deduplicates
    while preserving first-seen order, so the query string is stable.
    """
    # Synonym map, e.g. {"urlop": ["wypoczynek"]}; currently empty, so the
    # loop below is a no-op until expansions are configured.
    expansions = {}
    expanded = list(keywords)
    for keyword in keywords:
        expanded.extend(expansions.get(keyword.lower(), []))
    return " ".join(dict.fromkeys(expanded))
def extract_relevant_fragment(content, query, context_size=200):
    # NOTE(review): truncated leftover from an earlier revision of this file's
    # diff — it hard-codes "Art. 154", never uses `query`/`context_size`, and
    # falls off the end (implicitly returns None).  It is shadowed by the
    # extract_relevant_fragment defined later in the file, so it is dead code;
    # kept verbatim here, flagged for removal.
    article_pattern = r"Art\.\s*154\..*?(?=Art\.\s*\d+\.|\Z)"
    match = re.search(article_pattern, content, re.DOTALL)
def extract_full_article(content, article_number):
    """Return the full text of article *article_number* found in *content*.

    The article spans from its "Art. <number>." heading up to (but not
    including) the next "Art. <n>." heading, or to the end of the text.
    Returns ``None`` when the article is not present.
    """
    boundary = rf"Art\.\s*{article_number}\..*?(?=Art\.\s*\d+\.|\Z)"
    found = re.search(boundary, content, re.DOTALL)
    return found.group(0).strip() if found else None
def extract_relevant_fragment(content, query, context_size=100):
    """Return the fragment of *content* most relevant to *query*.

    Queries starting with "Art. N" get the whole article via
    extract_full_article; otherwise a ±``context_size`` window around the
    first case-insensitive occurrence of the query; otherwise the first
    400 characters of the document.

    Fix: the original ended with two stacked fallbacks
    (``return content[:400] + "..."`` followed by an unreachable
    ``return content[:200] + "..."`` — merge residue); the dead line is
    removed and the reachable behavior kept.
    """
    article_match = re.match(r"Art\.\s*(\d+)", query)
    if article_match:
        full_article = extract_full_article(content, article_match.group(1))
        if full_article:
            return full_article
    index = content.lower().find(query.lower())
    if index != -1:
        start = max(0, index - context_size)
        end = min(len(content), index + len(query) + context_size)
        return f"...{content[start:end]}..."
    return content[:400] + "..."
# NOTE(review): this span is an unresolved interleave of TWO revisions of
# hybrid_search as rendered by the diff view — an older per-keyword loop that
# accumulates into `all_results`, and a newer single joined-query version that
# accumulates into `results`.  Both search calls, both collection loops and
# both return statements are present, and the original indentation was lost in
# the paste, so the control structure cannot be reconstructed with confidence.
# The lines are kept verbatim and only annotated; this block needs a manual
# merge before it can run.
def hybrid_search(keywords, limit=5, alpha=0.5):
# Normalize a bare keyword string to a one-element list.
if isinstance(keywords, str):
keywords = [keywords]
# Newer variant: one hybrid query over all keywords joined together.
query = " ".join(keywords)
# Older variant: accumulator for the per-keyword loop below.
all_results = []
for keyword in keywords:
print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{keyword}'")
# Older variant: one Weaviate hybrid search per keyword; over-fetches
# (limit * 2) so post-filtering can still fill `limit` results.
response = collection.query.hybrid(
query=keyword,
alpha=alpha,
limit=limit * 2
)
print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{query}'")
# Newer variant: a single hybrid search with the joined query string.
response = collection.query.hybrid(
query=query,
alpha=alpha,
limit=limit * 2
)
results = []
# Newer variant's result loop: collect every hit unconditionally.
for obj in response.objects:
#print(f"UUID: {obj.uuid}")
relevant_fragment = extract_relevant_fragment(obj.properties['content'], query)
#print(f"Relewantny fragment:\n{relevant_fragment}")
#print(f"Nazwa pliku: {obj.properties['fileName']}")
#print("---")
# Changing the condition to 'any' instead of 'all'
#if any(term.lower() in relevant_fragment.lower() for term in keywords):
results.append({
"uuid": obj.uuid,
"relevant_fragment": relevant_fragment,
"file_name": obj.properties['fileName'],
"keyword": query
})
print(f"Dodano do wyników: {obj.uuid}")
# Older variant's result loop: keep a hit only if the keyword literally
# occurs in the extracted fragment, deduplicating by dict equality.
for obj in response.objects:
relevant_fragment = extract_relevant_fragment(obj.properties['content'], keyword)
if keyword.lower() in relevant_fragment.lower():
result = {
"uuid": obj.uuid,
"relevant_fragment": relevant_fragment,
"file_name": obj.properties['fileName'],
"keyword": keyword
}
if result not in all_results:
all_results.append(result)
print(f"UUID: {obj.uuid}")
print(f"Relewantny fragment:\n{relevant_fragment}")
print(f"Nazwa pliku: {obj.properties['fileName']}")
print("---")
# Older variant: stop as soon as `limit` results were collected.
if len(all_results) >= limit:
break
if len(all_results) >= limit:
# Newer variant: stop once `limit` results were collected.
if len(results) >= limit:
break
# Older variant's return.
return all_results[:limit]
# Newer variant's return.
return results[:limit]
class ChatRequest(BaseModel):
    """Request body of the legacy /chat endpoint.

    NOTE(review): this redefines the ChatRequest declared earlier in the file
    (the newer variant with `messages: List[Message]`, `stream`, `options`);
    at runtime this later definition shadows the earlier one, which would
    break the newer /api/chat handler — looks like an unresolved merge of two
    revisions, to be reconciled manually.
    """
    model: str
    # Raw message dicts (role/content), not the typed Message model.
    messages: list[dict]
    # The user's question, separate from the message history.
    query: str
@app.get("/api/tags")
async def tags_proxy():
async with httpx.AsyncClient() as client:
response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
return response.json()
class ChatResponse(BaseModel):
    """Response body of the legacy /chat endpoint.

    NOTE(review): this redefines the ChatResponse declared earlier in the file
    (the Ollama-mirroring variant with model/created_at/message/...); this
    later definition shadows it at runtime, which would break the newer
    /api/chat handler's return — unresolved merge residue, reconcile manually.
    """
    # The model's answer text.
    content: str
    # Raw fragments retrieved from Weaviate that backed the answer.
    weaviate_results: list
@app.get("/api/version")
async def tags_proxy():
async with httpx.AsyncClient() as client:
response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
return response.json()
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
try:
keywords = analyze_query(request.query)
weaviate_results = hybrid_search(keywords)
if not weaviate_results:
response = ollama_client.chat(
model=request.model,
messages=[{"role": "user", "content": f"Nie znalazłem informacji na temat: {request.query}. Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi."}]
)
else:
context = "Znalezione informacje:\n"
for item in weaviate_results:
context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n"
response = ollama_client.chat(
model=request.model,
messages=[
{"role": "system", "content": context},
{"role": "user", "content": f"Na podstawie powyższych informacji, odpowiedz na pytanie: {request.query}. Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł."}
]
)
@app.post("/api/generate")
async def generate_proxy(request: Request):
data = await request.json()
async with httpx.AsyncClient() as client:
response = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=data)
return response.json()
return ChatResponse(
content=response['message']['content'],
weaviate_results=weaviate_results
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
@app.get("/api/models")
async def list_models():
try:
models = ollama_client.list()
@@ -176,5 +165,71 @@ async def list_models():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
async def stream_chat(model, messages, options):
    """Relay Ollama's streaming chat reply to the caller, line by line.

    Any failure is emitted as a single JSON error line instead of raising,
    so the HTTP stream always terminates cleanly for the client.
    """
    payload = {"model": model, "messages": messages, "stream": True, "options": options}
    try:
        # Forward the request to Ollama asynchronously and re-yield each
        # streamed line (one JSON chunk per line) with a trailing newline.
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                f"{OLLAMA_BASE_URL}/api/chat",
                json=payload,
            ) as upstream:
                async for chunk in upstream.aiter_lines():
                    yield chunk + "\n"
    except Exception as exc:
        yield json.dumps({"error": str(exc)}) + "\n"
@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
try:
query = request.messages[-1].content if request.messages else ""
keywords = analyze_query(query)
weaviate_results = hybrid_search(keywords)
if not weaviate_results:
context = f"""
Nie znalazłem informacji na temat: {query}.
Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi.
"""
else:
context = "Znalezione informacje:\n"
for item in weaviate_results:
context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n"
messages_with_context =[
{"role": "system", "content": context},
{"role": "user", "content": f"""
Na podstawie powyższych informacji, odpowiedz na pytanie: {query}.
Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł.
"""}
]
if request.stream:
return StreamingResponse(stream_chat(request.model, messages_with_context, request.options), media_type="application/json")
ollama_response = ollama_client.chat(
model=request.model,
messages=messages_with_context,
stream=False,
options=request.options
)
return ChatResponse(
model=request.model,
created_at=ollama_response.get('created_at', ''),
message=Message(
role=ollama_response['message']['role'],
content=ollama_response['message']['content']
),
done=ollama_response.get('done', True),
total_duration=ollama_response.get('total_duration', 0),
load_duration=ollama_response.get('load_duration', 0),
prompt_eval_count=ollama_response.get('prompt_eval_count', 0),
prompt_eval_duration=ollama_response.get('prompt_eval_duration', 0),
eval_count=ollama_response.get('eval_count', 0),
eval_duration=ollama_response.get('eval_duration', 0)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -1,4 +1,5 @@
fastapi
uvicorn
ollama
weaviate-client
unidecode