mod allegro

This commit is contained in:
l.gabrysiak 2025-02-28 20:59:54 +01:00
parent 8e1f346f6e
commit 6a6546a03d
1 changed files with 36 additions and 38 deletions

View File

@@ -1,16 +1,16 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import weaviate
from weaviate.classes.config import Property, DataType, Configure
from weaviate.classes.query import Query
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from weaviate.classes.config import Property, DataType, Configure
from weaviate.classes.client import WeaviateClient
from weaviate.classes.config import ConnectionParams
# 1⃣ Połączenie z bazą Weaviate
client = weaviate.WeaviateClient(
connection_params=weaviate.ConnectionParams.from_params(
# 1⃣ Połączenie z Weaviate
client = WeaviateClient(
connection_params=ConnectionParams.from_params(
http_host="weaviate",
http_port=8080,
http_secure=False,
@@ -20,44 +20,42 @@ client = weaviate.WeaviateClient(
)
)
# 2⃣ Pobranie dokumentów z bazy Weaviate
collection_name = "Documents"
query = Query(collection_name).limit(1000)
result = client.query.run(query)
# 2⃣ Pobranie dokumentów z Weaviate
def fetch_documents():
    """Return the stored documents as ``"fileName: ..., content: ..."`` strings.

    Reads the ``Document`` collection through the module-level Weaviate
    ``client``. Objects whose ``content`` is missing or empty are skipped;
    objects without a ``fileName`` are labelled ``unknown_file``.

    Returns:
        list[str]: One formatted string per document with non-empty content.
    """
    # The v4 ``WeaviateClient`` does not expose the v3 GraphQL builder
    # (``client.query.get(...).do()``); objects are read through the
    # collections API instead.
    collection = client.collections.get("Document")
    documents = []
    # ``iterator()`` walks the whole collection, so the result is not capped
    # by the server's default query limit.
    for obj in collection.iterator(return_properties=["content", "fileName"]):
        props = obj.properties
        content = props.get("content") or ""
        if not content:
            continue
        # ``or`` also covers a property that is present but null.
        file_name = props.get("fileName") or "unknown_file"
        documents.append(f"fileName: {file_name}, content: {content}")
    return documents
documents = []
file_names = []
documents = fetch_documents()
for item in result[collection_name]['objects']:
documents.append(item['properties']['content'])
file_names.append(item['properties']['fileName'])
# 3⃣ Inicjalizacja modelu
model_name = "allegro/multislav-5lang"
device = "cuda" if torch.cuda.is_available() else "cpu"
# 3⃣ Tworzenie datasetu
training_data = {
"text": documents,
"file_name": file_names
}
dataset = Dataset.from_dict(training_data)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Podział na zestaw treningowy i ewaluacyjny
# 4⃣ Przygotowanie danych treningowych
def create_training_data():
    """Wrap the module-level ``documents`` list in a single-column dataset.

    Returns:
        Dataset: A dataset whose only column, ``text``, holds the document
        strings.
    """
    columns = {"text": documents}
    return Dataset.from_dict(columns)
# Build the dataset from the collected documents.
dataset = create_training_data()
# Hold out 25% of the examples for evaluation; the remaining 75% are the
# training split.
split_dataset = dataset.train_test_split(test_size=0.25)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
# 4⃣ Ładowanie modelu Multislav
model_name = "allegro/multislav-5lang"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 5⃣ Tokenizacja
max_length = 512
def tokenize_function(examples, max_length=512):
    """Tokenize the ``text`` entries of a batch of examples.

    Args:
        examples: Mapping with a ``"text"`` entry holding the strings to
            encode.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the given texts.
    """
    # ``max_length`` is passed exactly once: repeating the keyword in the
    # call is a SyntaxError ("keyword argument repeated").
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
tokenized_train = train_dataset.map(tokenize_function, batched=True)
@@ -68,13 +66,13 @@ training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
learning_rate=2e-5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
num_train_epochs=10,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
num_train_epochs=16,
weight_decay=0.01,
save_total_limit=2,
load_best_model_at_end=True,
metric_for_best_model="loss",
greater_is_better=False,
@@ -95,7 +93,7 @@ trainer = Trainer(
# Run training with the trainer configured earlier in the script.
trainer.train()
# 9. Save the model and tokenizer
# NOTE(review): the model and tokenizer are written to two directories
# ("./trained_model/multislav" and "./models/allegro"); this looks like the
# old and new sides of the diff shown together. Confirm which destination is
# intended.
model.save_pretrained("./trained_model/multislav")
tokenizer.save_pretrained("./trained_model/multislav")
model.save_pretrained("./models/allegro")
tokenizer.save_pretrained("./models/allegro")
print("✅ Model został wytrenowany i zapisany!")