mod allegro

l.gabrysiak 2025-02-28 20:58:24 +01:00
parent 4007d446e3
commit 8e1f346f6e
1 changed file with 39 additions and 58 deletions


@@ -2,21 +2,15 @@ import os
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import torch
-import numpy as np
-from sentence_transformers import SentenceTransformer
-from datasets import Dataset
-from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
 import weaviate
-from weaviate.client import WeaviateClient
-from weaviate.connect import ConnectionParams
+from weaviate.classes.config import Property, DataType, Configure
+from weaviate.classes.query import Query
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
+from datasets import Dataset
 
-# 1⃣ Initialize the embedding model
-embed_model = SentenceTransformer("all-MiniLM-L6-v2")
-
-# 2⃣ Connect to Weaviate and fetch the documents
-client = WeaviateClient(
-    connection_params=ConnectionParams.from_params(
+# 1⃣ Connect to the Weaviate database
+client = weaviate.WeaviateClient(
+    connection_params=weaviate.ConnectionParams.from_params(
         http_host="weaviate",
         http_port=8080,
         http_secure=False,
@@ -26,48 +20,38 @@ client = WeaviateClient(
     )
 )
 
-collection_name = "Document"  # Assuming this is the name of your collection
-result = (
-    client.query.get(collection_name, ["content"])
-    .with_additional(["id"])
-    .do()
-)
-documents = [item['content'] for item in result['data']['Get'][collection_name]]
+# 2⃣ Fetch the documents from the Weaviate database
+collection_name = "Documents"
+query = Query(collection_name).limit(1000)
+result = client.query.run(query)
 
-# 3⃣ Generate the embeddings
-embeddings = embed_model.encode(documents)
+documents = []
+file_names = []
+for item in result[collection_name]['objects']:
+    documents.append(item['properties']['content'])
+    file_names.append(item['properties']['fileName'])
 
-# 4⃣ Prepare the training data
-def create_training_data():
-    data = {
-        "text": documents,
-        "embedding": embeddings.tolist()
-    }
-    return Dataset.from_dict(data)
-dataset = create_training_data()
+# 3⃣ Build the dataset
+training_data = {
+    "text": documents,
+    "file_name": file_names
+}
+dataset = Dataset.from_dict(training_data)
 
-# Split the data into training and evaluation sets
+# Split into training and evaluation sets
 split_dataset = dataset.train_test_split(test_size=0.25)
 train_dataset = split_dataset["train"]
 eval_dataset = split_dataset["test"]
 
-# 5⃣ Load the allegro/multislav-5lang model
+# 4⃣ Load the Multislav model
+device = "cuda" if torch.cuda.is_available() else "cpu"
 model_name = "allegro/multislav-5lang"
-device = "cuda" if torch.cuda.is_available() else "cpu"
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# 6⃣ LoRA configuration
-lora_config = LoraConfig(
-    r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="SEQ_2_SEQ_LM"
-)
-model = get_peft_model(model, lora_config)
-
-# 7⃣ Tokenize the data
-max_length = 384
+# 5⃣ Tokenization
+max_length = 512
 def tokenize_function(examples):
     return tokenizer(
         examples["text"],
@@ -79,30 +63,27 @@ def tokenize_function(examples):
 tokenized_train = train_dataset.map(tokenize_function, batched=True)
 tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
 
-# 8️⃣ Training arguments
+# 6️⃣ Training arguments
 training_args = TrainingArguments(
     output_dir="./results",
-    eval_strategy="steps",
+    evaluation_strategy="steps",
     eval_steps=500,
     save_strategy="steps",
     save_steps=500,
-    learning_rate=1e-5,
-    per_device_train_batch_size=2,
-    per_device_eval_batch_size=2,
-    num_train_epochs=16,
+    learning_rate=2e-5,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    num_train_epochs=10,
     weight_decay=0.01,
     load_best_model_at_end=True,
     metric_for_best_model="loss",
     greater_is_better=False,
 )
 
-# 9⃣ Data Collator
-data_collator = DataCollatorForSeq2Seq(
-    tokenizer=tokenizer,
-    model=model
-)
+# 7⃣ Data Collator
+data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
 
-# 🔟 Train the model
+# 8⃣ Training
 trainer = Trainer(
     model=model,
     args=training_args,
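
Note (not part of the commit): tokenize_function only encodes the inputs, while Trainer needs a labels column to compute a seq2seq loss; DataCollatorForSeq2Seq pads labels but does not create them. A minimal sketch of one way to add them, under the placeholder assumption that each document is reused as its own target (a real run would pair inputs with reference translations):

def tokenize_function(examples):
    # Sketch only: text_target makes the tokenizer emit a "labels" field;
    # reusing the source text as the target is a placeholder, not a real training objective.
    return tokenizer(
        examples["text"],
        text_target=examples["text"],
        max_length=max_length,
        truncation=True,
    )

The eval_strategy / evaluation_strategy change in this hunk likely tracks the installed transformers version: older releases only accept the longer spelling, while recent ones accept both.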
@@ -113,8 +94,8 @@ trainer = Trainer(
 trainer.train()
 
-# 1⃣1️⃣ Save the model
-model.save_pretrained("./models/allegro")
-tokenizer.save_pretrained("./models/allegro")
+# 9️⃣ Save the model
+model.save_pretrained("./trained_model/multislav")
+tokenizer.save_pretrained("./trained_model/multislav")
 
 print("✅ The model has been trained and saved!")