ably.do/allegro.py

import os
import torch
import weaviate
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from weaviate.connect import ConnectionParams

# 1. Connect to Weaviate
client = weaviate.WeaviateClient(
    connection_params=ConnectionParams.from_params(
        http_host="weaviate",
        http_port=8080,
        http_secure=False,
        grpc_host="weaviate",
        grpc_port=50051,
        grpc_secure=False,
    )
)
client.connect()

# 2. Fetch documents from Weaviate
def fetch_documents():
    collection = client.collections.get("Document")
    response = collection.query.fetch_objects()
    documents = []
    # The v4 collections API returns objects with a .properties dict,
    # not the old GraphQL-style {"data": {"Get": ...}} payload.
    for obj in response.objects:
        file_name = obj.properties.get("fileName", "unknown_file")
        content = obj.properties.get("content", "")
        if content:
            documents.append(f"fileName: {file_name}, content: {content}")
            print(f"fileName: {file_name}")
    return documents

documents = fetch_documents()

# 3. Initialize the model
model_name = "allegro/multislav-5lang"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 4. Prepare the training data
def create_training_data():
    return Dataset.from_dict({"text": documents})

dataset = create_training_data()
split_dataset = dataset.train_test_split(test_size=0.25)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

# 5. Tokenization
def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    # Seq2seq training requires target labels; since the original script provides
    # none, the inputs are reused as targets here (an identity objective).
    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs

tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True)

# 6. Training parameters
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=16,
    weight_decay=0.01,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
)

# 7. Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# 8. Training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
)
trainer.train()

# 9. Save the model
model.save_pretrained("./models/allegro")
tokenizer.save_pretrained("./models/allegro")
print("✅ Model trained and saved!")