ably.do/allegro.py

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import weaviate
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from datasets import Dataset
# 1) Connect to the Weaviate instance (connect_to_custom builds the client and
#    opens the connection, which the bare WeaviateClient constructor does not)
client = weaviate.connect_to_custom(
    http_host="weaviate",
    http_port=8080,
    http_secure=False,
    grpc_host="weaviate",
    grpc_port=50051,
    grpc_secure=False,
)
# 2) Fetch documents from the Weaviate collection
collection_name = "Documents"
collection = client.collections.get(collection_name)
result = collection.query.fetch_objects(limit=1000)

documents = []
file_names = []
for obj in result.objects:
    documents.append(obj.properties["content"])
    file_names.append(obj.properties["fileName"])

client.close()
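
# Added sanity check (not in the original script): the dataset split below
# fails on an empty dataset, so abort early if the collection returned nothing.
if not documents:
    raise RuntimeError(f"No objects found in the '{collection_name}' collection")
print(f"Fetched {len(documents)} documents from Weaviate")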
# 3) Build the dataset
training_data = {
    "text": documents,
    "file_name": file_names,
}
dataset = Dataset.from_dict(training_data)

# Split into training and evaluation sets
split_dataset = dataset.train_test_split(test_size=0.25)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
# 4) Load the MultiSlav model
model_name = "allegro/multislav-5lang"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 5) Tokenization
max_length = 512

def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # The Trainer needs labels to compute a loss; with no parallel targets here,
    # reuse the inputs as labels (self-supervised reconstruction) and mask padding
    # with -100 so it is ignored by the loss. Swap in real targets if available.
    tokenized["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in ids]
        for ids in tokenized["input_ids"]
    ]
    return tokenized

tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
# 6) Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
)
# 7) Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# 8) Training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
)
trainer.train()
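# Added: report the final evaluation loss so the run records how the best
# checkpoint (restored via load_best_model_at_end) performs on the held-out split.
eval_metrics = trainer.evaluate()
print(f"Final eval loss: {eval_metrics['eval_loss']:.4f}")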
# 9) Save the model
model.save_pretrained("./trained_model/multislav")
tokenizer.save_pretrained("./trained_model/multislav")
print("✅ Model trained and saved!")