diff --git a/allegro.py b/allegro.py
index e2aa098..e0678ab 100644
--- a/allegro.py
+++ b/allegro.py
@@ -1,16 +1,16 @@
 import os
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
 import torch
 import weaviate
-from weaviate.classes.config import Property, DataType, Configure
-from weaviate.classes.query import Query
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
+import numpy as np
 from datasets import Dataset
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
+from weaviate.classes.config import Property, DataType, Configure
+from weaviate import WeaviateClient
+from weaviate.connect import ConnectionParams
 
-# 1️⃣ Connecting to the Weaviate database
-client = weaviate.WeaviateClient(
-    connection_params=weaviate.ConnectionParams.from_params(
+# 1️⃣ Connecting to Weaviate
+client = WeaviateClient(
+    connection_params=ConnectionParams.from_params(
         http_host="weaviate",
         http_port=8080,
         http_secure=False,
@@ -20,44 +20,47 @@ client = weaviate.WeaviateClient(
     )
 )
+client.connect()
 
-# 2️⃣ Fetching documents from the Weaviate database
-collection_name = "Documents"
-query = Query(collection_name).limit(1000)
-result = client.query.run(query)
+# 2️⃣ Fetching documents from Weaviate (v4 collections API)
+def fetch_documents():
+    collection = client.collections.get("Document")
+    response = collection.query.fetch_objects(limit=1000, return_properties=["content", "fileName"])
+    documents = []
+    for obj in response.objects:
+        file_name = obj.properties.get("fileName", "unknown_file")
+        content = obj.properties.get("content", "")
+        if content:
+            documents.append(f"fileName: {file_name}, content: {content}")
+    return documents
 
-documents = []
-file_names = []
+documents = fetch_documents()
 
-for item in result[collection_name]['objects']:
-    documents.append(item['properties']['content'])
-    file_names.append(item['properties']['fileName'])
+# 3️⃣ Model initialization
+model_name = "allegro/multislav-5lang"
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# 3️⃣ Creating the dataset
-training_data = {
-    "text": documents,
-    "file_name": file_names
-}
-dataset = Dataset.from_dict(training_data)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# Split into training and evaluation sets
+# 4️⃣ Preparing the training data
+def create_training_data():
+    return Dataset.from_dict({"text": documents})
+
+dataset = create_training_data()
 split_dataset = dataset.train_test_split(test_size=0.25)
 train_dataset = split_dataset["train"]
 eval_dataset = split_dataset["test"]
 
-# 4️⃣ Loading the MultiSlav model
-model_name = "allegro/multislav-5lang"
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-
 # 5️⃣ Tokenization
-max_length = 512
 def tokenize_function(examples):
-    return tokenizer(
+    tokens = tokenizer(
         examples["text"],
         padding="max_length",
         truncation=True,
-        max_length=max_length
+        max_length=512
     )
+    # Trainer needs a loss: mirror the inputs as labels (self-supervised setup)
+    tokens["labels"] = tokens["input_ids"].copy()
+    return tokens
 
 tokenized_train = train_dataset.map(tokenize_function, batched=True)
@@ -68,13 +71,13 @@ training_args = TrainingArguments(
     output_dir="./results",
     evaluation_strategy="steps",
     eval_steps=500,
-    save_strategy="steps",
     save_steps=500,
     learning_rate=2e-5,
-    per_device_train_batch_size=4,
-    per_device_eval_batch_size=4,
-    num_train_epochs=10,
+    per_device_train_batch_size=2,
+    per_device_eval_batch_size=2,
+    num_train_epochs=16,
     weight_decay=0.01,
+    save_total_limit=2,
     load_best_model_at_end=True,
     metric_for_best_model="loss",
     greater_is_better=False,
@@ -95,7 +98,7 @@ trainer = Trainer(
 trainer.train()
 
 # 9️⃣ Saving the model
-model.save_pretrained("./trained_model/multislav")
-tokenizer.save_pretrained("./trained_model/multislav")
+model.save_pretrained("./models/allegro")
+tokenizer.save_pretrained("./models/allegro")
 
 print("✅ Model trained and saved!")
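A minimal smoke-test sketch for the saved artifacts, assuming the training run finishes and writes to `./models/allegro` as in the patch above; the sample input string and generation settings are illustrative assumptions, not part of the patch:

```python
# Hypothetical smoke test: load the artifacts allegro.py saves and generate once.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_dir = "./models/allegro"  # save path used at the end of allegro.py
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)
model.eval()

# Illustrative input only: mirrors the "fileName: ..., content: ..." training format
text = "fileName: example.txt, content: To jest przykładowy dokument."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```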