diff --git a/gemma.py b/gemma.py index 5742741..aeb8eed 100644 --- a/gemma.py +++ b/gemma.py @@ -1,13 +1,12 @@ import torch -import chromadb +import faiss +import numpy as np from sentence_transformers import SentenceTransformer from datasets import Dataset from peft import LoraConfig, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer -# 1️⃣ Inicjalizacja ChromaDB i modelu do embeddingów -chroma_client = chromadb.PersistentClient(path="./chroma_db") -collection = chroma_client.get_or_create_collection("my_embeddings") +# 1️⃣ Inicjalizacja modelu do embeddingów embed_model = SentenceTransformer("all-MiniLM-L6-v2") # 2️⃣ Dodanie dokumentów i embeddingów @@ -17,40 +16,43 @@ documents = [ "Procedura składania reklamacji w e-sklepie.", "Jakie dokumenty są potrzebne do rejestracji działalności?" ] -embeddings = embed_model.encode(documents).tolist() +embeddings = embed_model.encode(documents) -for i, (doc, emb) in enumerate(zip(documents, embeddings)): - collection.add(ids=[str(i)], documents=[doc], embeddings=[emb]) +# 3️⃣ Inicjalizacja FAISS i dodanie wektorów +dim = embeddings.shape[1] # Wymiary wektorów +index = faiss.IndexFlatL2(dim) # Tworzymy indeks FAISS dla metryki L2 +index.add(np.array(embeddings, dtype=np.float32)) # Dodajemy wektory do indeksu FAISS -# 3️⃣ Przygotowanie danych treningowych +# 4️⃣ Przygotowanie danych treningowych def create_training_data(): - data = collection.get(include=["documents", "embeddings"]) - return Dataset.from_dict({ - "text": data["documents"], - "embedding": data["embeddings"] - }) + # Pobranie dokumentów (możesz połączyć je z odpowiednimi embeddingami, jeśli trzeba) + data = { + "text": documents, + "embedding": embeddings.tolist() + } + return Dataset.from_dict(data) dataset = create_training_data() -# 4️⃣ Ładowanie modelu Gemma 2 7B +# 5️⃣ Ładowanie modelu Gemma 2 7B device = "cuda" if torch.cuda.is_available() else "cpu" model_name = "google/gemma-7b" model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) tokenizer = AutoTokenizer.from_pretrained(model_name) -# 5️⃣ Konfiguracja LoRA dla efektywnego treningu +# 6️⃣ Konfiguracja LoRA dla efektywnego treningu lora_config = LoraConfig( r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM" ) model = get_peft_model(model, lora_config) -# 6️⃣ Tokenizacja danych +# 7️⃣ Tokenizacja danych def tokenize_function(examples): return tokenizer(examples["text"], padding="max_length", truncation=True) tokenized_dataset = dataset.map(tokenize_function, batched=True) -# 7️⃣ Parametry treningu +# 8️⃣ Parametry treningu training_args = TrainingArguments( output_dir="./results", per_device_train_batch_size=2, @@ -59,7 +61,7 @@ training_args = TrainingArguments( save_strategy="epoch" ) -# 8️⃣ Trening modelu +# 9️⃣ Trening modelu trainer = Trainer( model=model, args=training_args, @@ -68,8 +70,8 @@ trainer = Trainer( trainer.train() -# 9️⃣ Zapisanie dostrojonego modelu +# 🔟 Zapisanie dostrojonego modelu model.save_pretrained("./trained_model/gemma") tokenizer.save_pretrained("./trained_model/gemma") -print("✅ Model został wytrenowany i zapisany!") +print("✅ Model został wytrenowany i zapisany!") \ No newline at end of file