ably.do/gemma-faiss.py

62 lines
2.4 KiB
Python
Raw Normal View History

2025-02-26 09:30:23 -05:00
import numpy as np
import faiss
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import ollama
# 1⃣ Inicjalizacja modelu do embeddingów
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# 2⃣ Dodanie dokumentów i embeddingów
def read_documents_from_file(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read()
articles = content.split('\n\n')
documents = []
for article in articles:
if article.strip().startswith('Art.'):
documents.append(article.strip())
return documents
#documents = [
# "Jak założyć firmę w Polsce?",
# "Jak rozliczyć podatek VAT?",
# "Procedura składania reklamacji w e-sklepie.",
# "Jakie dokumenty są potrzebne do rejestracji działalności?"
#]
file_path = './docs/kodekspracy.txt' # Zmień na właściwą ścieżkę
documents = read_documents_from_file(file_path)
# 3⃣ Wygenerowanie embeddingów
embeddings = embed_model.encode(documents)
# 4⃣ Inicjalizacja FAISS
dim = embeddings.shape[1] # Wymiar embeddingu
index = faiss.IndexFlatL2(dim)
index.add(np.array(embeddings, dtype=np.float32))
# 5⃣ Wczytanie modelu Ollama (Gemma 2)
model_name = "./trained_model/gemma/Gemma-F16-LoRA.gguf" # Ścieżka do modelu w systemie
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
# 6⃣ Funkcja wyszukiwania w FAISS
def search(query, k=5):
query_embedding = embed_model.encode([query]) # Przekształć zapytanie w embedding
_, indices = index.search(np.array(query_embedding, dtype=np.float32), k) # Szukaj w indeksie FAISS
return indices # Zwróć indeksy najbardziej podobnych dokumentów
# 7⃣ Funkcja generowania odpowiedzi z Ollama
def generate_response(query):
indices = search(query) # Znajdź najbardziej podobne dokumenty
relevant_documents = [documents[i] for i in indices[0]] # Pobierz dokumenty na podstawie wyników wyszukiwania
prompt = " ".join(relevant_documents) + " " + query # Przygotuj prompt
response = ollama.chat(model=model_name, messages=[{"role": "user", "content": prompt}]) # Generowanie odpowiedzi przez Ollama
return response["text"]
# 8⃣ Interfejs Gradio (Open-WebUI)
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text")
iface.launch() # Uruchom interfejs