From 12cef050a2db9ce920c630f14739cb09757658dc Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Fri, 28 Feb 2025 21:26:21 +0100
Subject: [PATCH] mod allegro

---
 allegro.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++-----------------------------
 1 file changed, 48 insertions(+), 29 deletions(-)

diff --git a/allegro.py b/allegro.py
index 2ee14c1..54bc777 100644
--- a/allegro.py
+++ b/allegro.py
@@ -1,12 +1,18 @@
 import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 import torch
-import weaviate
+import faiss
 import numpy as np
-from datasets import Dataset
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
 from weaviate.connect import ConnectionParams
 import weaviate
-import tempfile
+
+from sentence_transformers import SentenceTransformer
+from datasets import Dataset
+from peft import LoraConfig, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
+
+embed_model = SentenceTransformer("all-MiniLM-L6-v2")  # sentence-embedding model used to vectorize the documents
 
 # 1️⃣ Connect to Weaviate
 client = weaviate.WeaviateClient(
@@ -38,67 +44,80 @@ def fetch_documents():
 
 documents = fetch_documents()
 
-# 3️⃣ Model initialization
-model_name = "allegro/multislav-5lang"
-device = "cuda" if torch.cuda.is_available() else "cpu"
+embeddings = embed_model.encode(documents)
 
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+dim = embeddings.shape[1]
+index = faiss.IndexFlatL2(dim)  # in-memory FAISS index over the document embeddings
+index.add(np.array(embeddings, dtype=np.float32))
 
-# 4️⃣ Prepare the training data
 def create_training_data():
-    return Dataset.from_dict({"text": documents})
+    data = {
+        "text": documents,
+        "embedding": embeddings.tolist()
+    }
+    return Dataset.from_dict(data)
 
 dataset = create_training_data()
+
 split_dataset = dataset.train_test_split(test_size=0.25)
 train_dataset = split_dataset["train"]
 eval_dataset = split_dataset["test"]
 
-# 5️⃣ Tokenization
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_name = "allegro/multislav-5lang"
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+lora_config = LoraConfig(  # LoRA adapters for parameter-efficient fine-tuning
+    r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
+)
+model = get_peft_model(model, lora_config)
+
+max_length = 384
+
 def tokenize_function(examples):
     return tokenizer(
         examples["text"],
         padding="max_length",
         truncation=True,
-        max_length=512
+        max_length=max_length
     )
 
 tokenized_train = train_dataset.map(tokenize_function, batched=True)
 tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
 
-# 6️⃣ Training parameters
 training_args = TrainingArguments(
     output_dir="./results",
-    evaluation_strategy="steps",
-    eval_steps=500,
-    save_steps=500,
-    learning_rate=2e-5,
+    eval_strategy="steps",  # evaluate every fixed number of steps
+    eval_steps=500,  # evaluate every 500 steps
+    save_strategy="steps",  # save a checkpoint every fixed number of steps
+    save_steps=500,  # save a checkpoint every 500 steps
+    learning_rate=1e-5,
     per_device_train_batch_size=2,
     per_device_eval_batch_size=2,
     num_train_epochs=16,
     weight_decay=0.01,
-    save_total_limit=2,
-    load_best_model_at_end=True,
-    metric_for_best_model="loss",
-    greater_is_better=False,
+    load_best_model_at_end=True,  # load the best checkpoint at the end of training
+    metric_for_best_model="loss",  # criterion for selecting the best model
+    greater_is_better=False,  # lower loss = better model
 )
 
-# 7️⃣ Data Collator
-data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
+data_collator = DataCollatorForLanguageModeling(
+    tokenizer=tokenizer,
+    mlm=False  # causal-LM objective: labels are copied from input_ids
+)
 
-# 8️⃣ Training
 trainer = Trainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_train,
-    eval_dataset=tokenized_eval,
+    eval_dataset=tokenized_eval,  # evaluation split added
     data_collator=data_collator,
 )
 
 trainer.train()
 
-# 9️⃣ Save the model
-model.save_pretrained("./models/allegro")
-tokenizer.save_pretrained("./models/allegro")
+model.save_pretrained("./trained_model/gemma")
+tokenizer.save_pretrained("./trained_model/gemma")
 
 print("✅ Model trained and saved!")
\ No newline at end of file