code ready to deploy
This commit is contained in:
parent 10f40c54ee
commit ad08ce74ce
@@ -1,11 +1,16 @@
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 import ollama
 import weaviate
 from weaviate.connect import ConnectionParams
 from weaviate.collections.classes.filters import Filter
 import re
+import json
 import uvicorn
+import httpx
+from typing import List, Optional
+import asyncio
 
 app = FastAPI()
 
@@ -25,9 +30,30 @@ weaviate_client = weaviate.WeaviateClient(
     )
 )
 weaviate_client.connect()
 # Fetch the collection
 collection = weaviate_client.collections.get("Document")
 
+class Message(BaseModel):
+    role: str
+    content: str
+
+class ChatRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    stream: Optional[bool] = False
+    options: Optional[dict] = None
+
+class ChatResponse(BaseModel):
+    model: str
+    created_at: str
+    message: Message
+    done: bool
+    total_duration: int
+    load_duration: int
+    prompt_eval_count: int
+    prompt_eval_duration: int
+    eval_count: int
+    eval_duration: int
+
 prompt = """
 Jesteś precyzyjnym narzędziem do generowania słów kluczowych z zakresu BHP i prawa pracy. Twoje zadanie to podanie WYŁĄCZNIE najistotniejszych słów do wyszukiwania w bazie dokumentów prawnych.
 
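Note on the models added above: Message/ChatRequest/ChatResponse mirror Ollama's /api/chat wire format, so an existing Ollama client can point at this proxy unchanged. A minimal sketch of a payload the new ChatRequest would accept; the model name "llama3" is a placeholder, not part of the commit:

    # Hypothetical payload for illustration only; "llama3" is a placeholder.
    example_request = {
        "model": "llama3",
        "messages": [{"role": "user", "content": "How many vacation days after 10 years of employment?"}],
        "stream": False,
        "options": {"temperature": 0.2},
    }
    req = ChatRequest(**example_request)  # pydantic coerces the message dicts into Message objects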
@@ -55,120 +81,83 @@ def analyze_query(query)
     print("Słowa kluczowe:", keywords)
     return keywords
 
-def extract_relevant_fragment(content, query, context_size=200):
-    article_match = re.match(r'Art\.\s*(\d+)', query)
-    if article_match:
-        article_number = article_match.group(1)
-        article_pattern = rf"Art\.\s*{article_number}\..*?(?=Art\.\s*\d+\.|\Z)"
-        match = re.search(article_pattern, content, re.DOTALL)
-        if match:
-            return match.group(0).strip()
-
-    index = content.lower().find(query.lower())
-    if index != -1:
-        start = max(0, index - context_size)
-        end = min(len(content), index + len(query) + context_size)
-        return f"...{content[start:end]}..."
-    return content[:400] + "..."
-
 def expand_query(keywords):
     expansions = {}
     expanded_terms = keywords.copy()
     for keyword in keywords:
         expanded_terms.extend(expansions.get(keyword.lower(), []))
     return " ".join(set(expanded_terms))
 
-def extract_relevant_fragment(content, query, context_size=200):
-    article_pattern = r"Art\.\s*154\..*?(?=Art\.\s*\d+\.|\Z)"
-    match = re.search(article_pattern, content, re.DOTALL)
+def extract_full_article(content, article_number):
+    pattern = rf"Art\.\s*{article_number}\..*?(?=Art\.\s*\d+\.|\Z)"
+    match = re.search(pattern, content, re.DOTALL)
     if match:
         return match.group(0).strip()
 
+    return None
 
 def extract_relevant_fragment(content, query, context_size=100):
     article_match = re.match(r"Art\.\s*(\d+)", query)
     if article_match:
         article_number = article_match.group(1)
+        full_article = extract_full_article(content, article_number)
+        if full_article:
+            return full_article
 
     index = content.lower().find(query.lower())
     if index != -1:
         start = max(0, index - context_size)
         end = min(len(content), index + len(query) + context_size)
         return f"...{content[start:end]}..."
-    return content[:400] + "..."
+    return content[:200] + "..."
 
 def hybrid_search(keywords, limit=5, alpha=0.5):
     if isinstance(keywords, str):
         keywords = [keywords]
 
-    all_results = []
-    for keyword in keywords:
-        print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{keyword}'")
-        response = collection.query.hybrid(
-            query=keyword,
-            alpha=alpha,
-            limit=limit * 2
-        )
-
-        for obj in response.objects:
-            relevant_fragment = extract_relevant_fragment(obj.properties['content'], keyword)
-            if keyword.lower() in relevant_fragment.lower():
-                result = {
-                    "uuid": obj.uuid,
-                    "relevant_fragment": relevant_fragment,
-                    "file_name": obj.properties['fileName'],
-                    "keyword": keyword
-                }
-                if result not in all_results:
-                    all_results.append(result)
-                    print(f"UUID: {obj.uuid}")
-                    print(f"Relewantny fragment:\n{relevant_fragment}")
-                    print(f"Nazwa pliku: {obj.properties['fileName']}")
-                    print("---")
-
-            if len(all_results) >= limit:
-                break
-
-        if len(all_results) >= limit:
-            break
-
-    return all_results[:limit]
+    query = " ".join(keywords)
+
+    print(f"\nWyszukiwanie hybrydowe dla słowa kluczowego: '{query}'")
+    response = collection.query.hybrid(
+        query=query,
+        alpha=alpha,
+        limit=limit * 2
+    )
+
+    results = []
+
+    for obj in response.objects:
+        #print(f"UUID: {obj.uuid}")
+        relevant_fragment = extract_relevant_fragment(obj.properties['content'], query)
+        #print(f"Relewantny fragment:\n{relevant_fragment}")
+        #print(f"Nazwa pliku: {obj.properties['fileName']}")
+        #print("---")
+        # Change the condition to 'any' instead of 'all'
+        #if any(term.lower() in relevant_fragment.lower() for term in keywords):
+        results.append({
+            "uuid": obj.uuid,
+            "relevant_fragment": relevant_fragment,
+            "file_name": obj.properties['fileName'],
+            "keyword": query
+        })
+        print(f"Dodano do wyników: {obj.uuid}")
+
+        if len(results) >= limit:
+            break
+
+    return results[:limit]
 
-class ChatRequest(BaseModel):
-    model: str
-    messages: list[dict]
-    query: str
+@app.get("/api/tags")
+async def tags_proxy():
+    async with httpx.AsyncClient() as client:
+        response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
+        return response.json()
 
-class ChatResponse(BaseModel):
-    content: str
-    weaviate_results: list
+@app.get("/api/version")
+async def version_proxy():
+    async with httpx.AsyncClient() as client:
+        response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
+        return response.json()
 
-@app.post("/chat", response_model=ChatResponse)
-async def chat_endpoint(request: ChatRequest):
-    try:
-        keywords = analyze_query(request.query)
-        weaviate_results = hybrid_search(keywords)
-
-        if not weaviate_results:
-            response = ollama_client.chat(
-                model=request.model,
-                messages=[{"role": "user", "content": f"Nie znalazłem informacji na temat: {request.query}. Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi."}]
-            )
-        else:
-            context = "Znalezione informacje:\n"
-            for item in weaviate_results:
-                context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n"
-
-            response = ollama_client.chat(
-                model=request.model,
-                messages=[
-                    {"role": "system", "content": context},
-                    {"role": "user", "content": f"Na podstawie powyższych informacji, odpowiedz na pytanie: {request.query}. Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł."}
-                ]
-            )
-
-        return ChatResponse(
-            content=response['message']['content'],
-            weaviate_results=weaviate_results
-        )
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+@app.post("/api/generate")
+async def generate_proxy(request: Request):
+    data = await request.json()
+    async with httpx.AsyncClient() as client:
+        response = await client.post(f"{OLLAMA_BASE_URL}/api/generate", json=data)
+        return response.json()
 
-@app.get("/models")
+@app.get("/api/models")
 async def list_models():
     try:
         models = ollama_client.list()
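Note on hybrid_search: the rewrite replaces the per-keyword query loop with a single Weaviate hybrid query over the joined keywords. In weaviate-client v4, alpha blends the two scores: alpha=0.0 is pure BM25 keyword matching, alpha=1.0 is pure vector search, and the default 0.5 weights them equally. A standalone sketch, assuming the "Document" collection is already populated; the keyword strings are illustrative:

    hits = hybrid_search(["urlop wypoczynkowy", "Art. 154"], limit=3, alpha=0.5)
    for hit in hits:
        print(hit["file_name"], "->", hit["relevant_fragment"][:80])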
@@ -176,5 +165,71 @@ async def list_models():
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+async def stream_chat(model, messages, options):
+    try:
+        # Use httpx to stream the response from Ollama asynchronously
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST",
+                f"{OLLAMA_BASE_URL}/api/chat",
+                json={"model": model, "messages": messages, "stream": True, "options": options},
+            ) as response:
+                async for line in response.aiter_lines():
+                    yield line + "\n"
+    except Exception as e:
+        yield json.dumps({"error": str(e)}) + "\n"
+
+@app.post("/api/chat")
+async def chat_endpoint(request: ChatRequest):
+    try:
+        query = request.messages[-1].content if request.messages else ""
+        keywords = analyze_query(query)
+        weaviate_results = hybrid_search(keywords)
+
+        if not weaviate_results:
+            context = f"""
+            Nie znalazłem informacji na temat: {query}.
+            Proszę poinformuj użytkownika, że nie masz wystarczającej wiedzy, aby udzielić jednoznacznej odpowiedzi.
+            """
+        else:
+            context = "Znalezione informacje:\n"
+            for item in weaviate_results:
+                context += f"Źródło: {item['file_name']}\nFragment: {item['relevant_fragment']}\n\n"
+
+        messages_with_context = [
+            {"role": "system", "content": context},
+            {"role": "user", "content": f"""
+            Na podstawie powyższych informacji, odpowiedz na pytanie: {query}.
+            Odwołaj się do konkretnych artykułów lub zacytuj fragmenty źródeł.
+            """}
+        ]
+
+        if request.stream:
+            return StreamingResponse(stream_chat(request.model, messages_with_context, request.options), media_type="application/json")
+
+        ollama_response = ollama_client.chat(
+            model=request.model,
+            messages=messages_with_context,
+            stream=False,
+            options=request.options
+        )
+        return ChatResponse(
+            model=request.model,
+            created_at=ollama_response.get('created_at', ''),
+            message=Message(
+                role=ollama_response['message']['role'],
+                content=ollama_response['message']['content']
+            ),
+            done=ollama_response.get('done', True),
+            total_duration=ollama_response.get('total_duration', 0),
+            load_duration=ollama_response.get('load_duration', 0),
+            prompt_eval_count=ollama_response.get('prompt_eval_count', 0),
+            prompt_eval_duration=ollama_response.get('prompt_eval_duration', 0),
+            eval_count=ollama_response.get('eval_count', 0),
+            eval_duration=ollama_response.get('eval_duration', 0)
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
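Usage sketch for the new /api/chat endpoint: with stream=true the service relays Ollama's newline-delimited JSON chunks through StreamingResponse, so a client can read them line by line. This assumes the service runs on localhost:8000 and the named model is pulled in Ollama; neither the URL nor the model name is part of the commit:

    import httpx, json

    payload = {
        "model": "llama3",
        "messages": [{"role": "user", "content": "How many vacation days does the Labour Code grant?"}],
        "stream": True,
    }
    with httpx.stream("POST", "http://localhost:8000/api/chat", json=payload, timeout=None) as r:
        for line in r.iter_lines():
            if line:
                chunk = json.loads(line)
                print(chunk.get("message", {}).get("content", ""), end="", flush=True)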
@@ -1,4 +1,5 @@
 fastapi
 uvicorn
 ollama
 weaviate-client
+unidecode
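One dependency note: the server code imports httpx, which requirements.txt does not pin; it is usually pulled in transitively (the ollama client depends on it), but an explicit entry would make the install reproducible. A suggested line, not part of the commit:

    httpx  # suggested explicit pin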