mod allegro

This commit is contained in:
l.gabrysiak 2025-02-28 20:59:54 +01:00
parent 8e1f346f6e
commit 6a6546a03d
1 changed files with 36 additions and 38 deletions

View File

@ -1,16 +1,16 @@
import os import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch import torch
import weaviate import weaviate
from weaviate.classes.config import Property, DataType, Configure import numpy as np
from weaviate.classes.query import Query
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from datasets import Dataset from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from weaviate.classes.config import Property, DataType, Configure
from weaviate.classes.client import WeaviateClient
from weaviate.classes.config import ConnectionParams
# 1⃣ Połączenie z bazą Weaviate # 1⃣ Połączenie z Weaviate
client = weaviate.WeaviateClient( client = WeaviateClient(
connection_params=weaviate.ConnectionParams.from_params( connection_params=ConnectionParams.from_params(
http_host="weaviate", http_host="weaviate",
http_port=8080, http_port=8080,
http_secure=False, http_secure=False,
@ -20,44 +20,42 @@ client = weaviate.WeaviateClient(
) )
) )
# 2⃣ Pobranie dokumentów z bazy Weaviate # 2⃣ Pobranie dokumentów z Weaviate
collection_name = "Documents" def fetch_documents():
query = Query(collection_name).limit(1000) query = client.query.get("Document", ["content", "fileName"]).do()
result = client.query.run(query) documents = []
for item in query["data"]["Get"]["Document"]:
file_name = item.get("fileName", "unknown_file")
content = item.get("content", "")
if content:
documents.append(f"fileName: {file_name}, content: {content}")
return documents
documents = [] documents = fetch_documents()
file_names = []
for item in result[collection_name]['objects']: # 3⃣ Inicjalizacja modelu
documents.append(item['properties']['content']) model_name = "allegro/multislav-5lang"
file_names.append(item['properties']['fileName']) device = "cuda" if torch.cuda.is_available() else "cpu"
# 3⃣ Tworzenie datasetu model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
training_data = { tokenizer = AutoTokenizer.from_pretrained(model_name)
"text": documents,
"file_name": file_names
}
dataset = Dataset.from_dict(training_data)
# Podział na zestaw treningowy i ewaluacyjny # 4⃣ Przygotowanie danych treningowych
def create_training_data():
return Dataset.from_dict({"text": documents})
dataset = create_training_data()
split_dataset = dataset.train_test_split(test_size=0.25) split_dataset = dataset.train_test_split(test_size=0.25)
train_dataset = split_dataset["train"] train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"] eval_dataset = split_dataset["test"]
# 4⃣ Ładowanie modelu Multislav
model_name = "allegro/multislav-5lang"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 5⃣ Tokenizacja # 5⃣ Tokenizacja
max_length = 512
def tokenize_function(examples): def tokenize_function(examples):
return tokenizer( return tokenizer(
examples["text"], examples["text"],
padding="max_length", padding="max_length",
truncation=True, truncation=True,
max_length=max_length max_length=512
) )
tokenized_train = train_dataset.map(tokenize_function, batched=True) tokenized_train = train_dataset.map(tokenize_function, batched=True)
@ -68,13 +66,13 @@ training_args = TrainingArguments(
output_dir="./results", output_dir="./results",
evaluation_strategy="steps", evaluation_strategy="steps",
eval_steps=500, eval_steps=500,
save_strategy="steps",
save_steps=500, save_steps=500,
learning_rate=2e-5, learning_rate=2e-5,
per_device_train_batch_size=4, per_device_train_batch_size=2,
per_device_eval_batch_size=4, per_device_eval_batch_size=2,
num_train_epochs=10, num_train_epochs=16,
weight_decay=0.01, weight_decay=0.01,
save_total_limit=2,
load_best_model_at_end=True, load_best_model_at_end=True,
metric_for_best_model="loss", metric_for_best_model="loss",
greater_is_better=False, greater_is_better=False,
@ -95,7 +93,7 @@ trainer = Trainer(
trainer.train() trainer.train()
# 9⃣ Zapis modelu # 9⃣ Zapis modelu
model.save_pretrained("./trained_model/multislav") model.save_pretrained("./models/allegro")
tokenizer.save_pretrained("./trained_model/multislav") tokenizer.save_pretrained("./models/allegro")
print("✅ Model został wytrenowany i zapisany!") print("✅ Model został wytrenowany i zapisany!")