mod allegro

l.gabrysiak 2025-02-28 20:58:24 +01:00
parent 4007d446e3
commit 8e1f346f6e
1 changed file with 39 additions and 58 deletions


@@ -2,21 +2,15 @@ import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import weaviate
from weaviate.client import WeaviateClient
from weaviate.connect import ConnectionParams
from weaviate.classes.config import Property, DataType, Configure
from weaviate.classes.query import Query
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from datasets import Dataset
# 1️⃣ Initialize the embedding model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# 2️⃣ Connect to Weaviate and fetch the documents
client = WeaviateClient(
connection_params=ConnectionParams.from_params(
# 1️⃣ Connect to the Weaviate database
client = weaviate.WeaviateClient(
connection_params=weaviate.ConnectionParams.from_params(
http_host="weaviate",
http_port=8080,
http_secure=False,
@@ -26,48 +20,38 @@ client = WeaviateClient(
)
)
collection_name = "Document"  # Assuming this is the name of your collection
result = (
client.query.get(collection_name, ["content"])
.with_additional(["id"])
.do()
)
# 2️⃣ Fetch the documents from the Weaviate database
collection_name = "Documents"
query = Query(collection_name).limit(1000)
result = client.query.run(query)
documents = [item['content'] for item in result['data']['Get'][collection_name]]
documents = []
file_names = []
# 3️⃣ Generate the embeddings
embeddings = embed_model.encode(documents)
for item in result[collection_name]['objects']:
documents.append(item['properties']['content'])
file_names.append(item['properties']['fileName'])
# 4️⃣ Prepare the training data
def create_training_data():
data = {
"text": documents,
"embedding": embeddings.tolist()
}
return Dataset.from_dict(data)
# 3️⃣ Build the dataset
training_data = {
"text": documents,
"file_name": file_names
}
dataset = Dataset.from_dict(training_data)
dataset = create_training_data()
# Split the data into training and evaluation sets
# Split into a training and an evaluation set
split_dataset = dataset.train_test_split(test_size=0.25)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
# 5️⃣ Load the allegro/multislav-5lang model
device = "cuda" if torch.cuda.is_available() else "cpu"
# 4️⃣ Load the Multislav model
model_name = "allegro/multislav-5lang"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 6️⃣ LoRA configuration
lora_config = LoraConfig(
r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(model, lora_config)
# 7️⃣ Tokenize the data
max_length = 384
# 5️⃣ Tokenization
max_length = 512
def tokenize_function(examples):
return tokenizer(
examples["text"],
@@ -79,30 +63,27 @@ def tokenize_function(examples):
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
# 8️⃣ Training parameters
# 6️⃣ Training parameters
training_args = TrainingArguments(
output_dir="./results",
eval_strategy="steps",
evaluation_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
learning_rate=1e-5,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
num_train_epochs=16,
learning_rate=2e-5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
num_train_epochs=10,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="loss",
greater_is_better=False,
)
# 9️⃣ Data Collator
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model
)
# 7️⃣ Data Collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# 🔟 Train the model
# 8️⃣ Training
trainer = Trainer(
model=model,
args=training_args,
@@ -113,8 +94,8 @@ trainer = Trainer(
trainer.train()
# 1️⃣1️⃣ Save the model
model.save_pretrained("./models/allegro")
tokenizer.save_pretrained("./models/allegro")
# 9️⃣ Save the model
model.save_pretrained("./trained_model/multislav")
tokenizer.save_pretrained("./trained_model/multislav")
print("✅ The model has been trained and saved!")
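
For reference, below is a minimal, hedged sketch of the document-fetching and dataset-building steps written against the v4 weaviate-client collections API, since the Query / client.query.run calls in the new version do not appear to match the documented v4 interface. The collection and property names ("Documents", "content", "fileName"), the model name, and the max_length of 512 come from the diff above; the gRPC port and the self-supervised labels (the input text reused as the target so that the Trainer can compute a loss) are assumptions, not part of this commit.

# Hedged sketch, not part of this commit: fetch documents via the v4 collections API
# and build a tokenized dataset that includes a "labels" column.
import weaviate
from datasets import Dataset
from transformers import AutoTokenizer

client = weaviate.connect_to_custom(
    http_host="weaviate", http_port=8080, http_secure=False,
    grpc_host="weaviate", grpc_port=50051, grpc_secure=False,  # gRPC port is assumed
)
try:
    # Read up to 1000 objects from the "Documents" collection with the two properties
    # used in the diff above.
    docs = client.collections.get("Documents")
    response = docs.query.fetch_objects(
        limit=1000, return_properties=["content", "fileName"]
    )
    texts = [obj.properties["content"] for obj in response.objects]
    file_names = [obj.properties["fileName"] for obj in response.objects]
finally:
    client.close()

dataset = Dataset.from_dict({"text": texts, "file_name": file_names})

tokenizer = AutoTokenizer.from_pretrained("allegro/multislav-5lang")

def tokenize_function(examples):
    # Tokenize the inputs and reuse the same text as the target (an assumption):
    # a seq2seq model only returns a loss when "labels" are provided, so the
    # Trainer in the diff needs this column to train at all.
    model_inputs = tokenizer(examples["text"], truncation=True, max_length=512)
    labels = tokenizer(text_target=examples["text"], truncation=True, max_length=512)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)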