From 8e1f346f6ef64819d8f39ef791b27a230b023544 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Fri, 28 Feb 2025 20:58:24 +0100
Subject: [PATCH] mod allegro

---
 allegro.py | 97 ++++++++++++++++++++++--------------------------------
 1 file changed, 39 insertions(+), 58 deletions(-)

diff --git a/allegro.py b/allegro.py
index 978af6f..e2aa098 100644
--- a/allegro.py
+++ b/allegro.py
@@ -2,21 +2,15 @@
 import os
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import torch
-import numpy as np
-from sentence_transformers import SentenceTransformer
-from datasets import Dataset
-from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
 import weaviate
-from weaviate.client import WeaviateClient
-from weaviate.connect import ConnectionParams
+from weaviate.classes.config import Property, DataType, Configure
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
+from datasets import Dataset
 
-# 1️⃣ Initialize the embedding model
-embed_model = SentenceTransformer("all-MiniLM-L6-v2")
-
-# 2️⃣ Connect to Weaviate and fetch the documents
-client = WeaviateClient(
-    connection_params=ConnectionParams.from_params(
+# 1️⃣ Connect to the Weaviate database
+client = weaviate.WeaviateClient(
+    connection_params=weaviate.ConnectionParams.from_params(
         http_host="weaviate",
         http_port=8080,
         http_secure=False,
@@ -26,48 +20,38 @@ client = WeaviateClient(
     )
 )
 
-collection_name = "Document"  # Assuming this is the name of your collection
-result = (
-    client.query.get(collection_name, ["content"])
-    .with_additional(["id"])
-    .do()
-)
+# 2️⃣ Fetch documents from Weaviate (v4 collections API)
+client.connect()
+collection = client.collections.get("Documents")
+result = collection.query.fetch_objects(limit=1000)
 
-documents = [item['content'] for item in result['data']['Get'][collection_name]]
+documents = []
+file_names = []
 
-# 3️⃣ Generate embeddings
-embeddings = embed_model.encode(documents)
+for item in result.objects:
+    documents.append(item.properties['content'])
+    file_names.append(item.properties['fileName'])
 
-# 4️⃣ Prepare the training data
-def create_training_data():
-    data = {
-        "text": documents,
-        "embedding": embeddings.tolist()
-    }
-    return Dataset.from_dict(data)
+# 3️⃣ Build the dataset
+training_data = {
+    "text": documents,
+    "file_name": file_names
+}
+dataset = Dataset.from_dict(training_data)
 
-dataset = create_training_data()
-
-# Split the data into training and evaluation sets
+# Split into a training and an evaluation set
 split_dataset = dataset.train_test_split(test_size=0.25)
 train_dataset = split_dataset["train"]
 eval_dataset = split_dataset["test"]
 
-# 5️⃣ Load the allegro/multislav-5lang model
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# 4️⃣ Load the MultiSlav model
 model_name = "allegro/multislav-5lang"
+device = "cuda" if torch.cuda.is_available() else "cpu"
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# 6️⃣ LoRA configuration
-lora_config = LoraConfig(
-    r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="SEQ_2_SEQ_LM"
-)
-model = get_peft_model(model, lora_config)
-
-# 7️⃣ Tokenize the data
-max_length = 384
-
+# 5️⃣ Tokenization
+max_length = 512
 def tokenize_function(examples):
     return tokenizer(
         examples["text"],
@@ -79,30 +63,27 @@ def tokenize_function(examples):
 tokenized_train = train_dataset.map(tokenize_function, batched=True)
 tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
 
-# 8️⃣ Training parameters
+# 6️⃣ Training parameters
 training_args = TrainingArguments(
     output_dir="./results",
-    eval_strategy="steps",
+    evaluation_strategy="steps",
     eval_steps=500,
     save_strategy="steps",
     save_steps=500,
-    learning_rate=1e-5,
-    per_device_train_batch_size=2,
-    per_device_eval_batch_size=2,
-    num_train_epochs=16,
+    learning_rate=2e-5,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    num_train_epochs=10,
     weight_decay=0.01,
     load_best_model_at_end=True,
     metric_for_best_model="loss",
     greater_is_better=False,
 )
 
-# 9️⃣ Data Collator
-data_collator = DataCollatorForSeq2Seq(
-    tokenizer=tokenizer,
-    model=model
-)
+# 7️⃣ Data Collator
+data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
 
-# 🔟 Train the model
+# 8️⃣ Training
 trainer = Trainer(
     model=model,
     args=training_args,
@@ -113,8 +94,8 @@ trainer = Trainer(
 
 trainer.train()
 
-# 1️⃣1️⃣ Save the model
-model.save_pretrained("./models/allegro")
-tokenizer.save_pretrained("./models/allegro")
+# 9️⃣ Save the model
+model.save_pretrained("./trained_model/multislav")
+tokenizer.save_pretrained("./trained_model/multislav")
 
 print("✅ Model has been trained and saved!")
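
Note on training: the visible part of tokenize_function only encodes examples["text"] as model inputs. A seq2seq model such as allegro/multislav-5lang needs a "labels" column so that Trainer can compute a loss, and DataCollatorForSeq2Seq expects to pad that column; if the unchanged remainder of tokenize_function does not produce labels, training will fail at the first step. Below is a minimal sketch of one way to supply them, assuming the same text is reused as both source and target; the name tokenize_with_labels and that copy-style setup are illustrative assumptions, not part of this patch.

def tokenize_with_labels(examples):
    # Encode the source text; DataCollatorForSeq2Seq pads dynamically later.
    model_inputs = tokenizer(
        examples["text"],
        max_length=max_length,
        truncation=True,
    )
    # Assumption: reuse the same text as the target sequence so that the
    # Trainer has a "labels" column to compute the seq2seq loss against.
    labels = tokenizer(
        text_target=examples["text"],
        max_length=max_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

With such a variant, tokenized_train and tokenized_eval would carry input_ids, attention_mask and labels, which is what DataCollatorForSeq2Seq and Trainer expect for a seq2seq fine-tuning run.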