mod gemma

l.gabrysiak 2025-02-26 13:15:18 +01:00
parent 30a6350071
commit b822c32206
1 changed file with 34 additions and 11 deletions


@ -1,10 +1,13 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
# 1⃣ Initialize the embedding model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
@ -19,13 +22,12 @@ documents = [
embeddings = embed_model.encode(documents)
# 3⃣ Initialize FAISS and add the vectors
dim = embeddings.shape[1]  # Vector dimensionality
index = faiss.IndexFlatL2(dim)  # Create a FAISS index for the L2 metric
index.add(np.array(embeddings, dtype=np.float32))  # Add the vectors to the FAISS index
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(np.array(embeddings, dtype=np.float32))
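The script only builds the index; no retrieval step appears in this diff. A minimal sketch of how the index could be queried for the nearest documents (the retrieve helper, the query string and k are illustrative, not part of the commit):

# Hypothetical helper, not part of this commit: nearest-neighbour lookup in the FAISS index
def retrieve(query, k=2):
    query_vec = embed_model.encode([query]).astype(np.float32)  # shape (1, dim)
    distances, indices = index.search(query_vec, k)              # top-k by L2 distance
    return [documents[i] for i in indices[0]]

print(retrieve("example query"))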
# 4⃣ Prepare the training data
def create_training_data():
    # Fetch the documents (you can combine them with their embeddings if needed)
    data = {
        "text": documents,
        "embedding": embeddings.tolist()
@ -47,8 +49,15 @@ lora_config = LoraConfig(
model = get_peft_model(model, lora_config)
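The parameters passed to LoraConfig just above are not visible in this diff. For orientation, a typical causal-LM LoRA setup might look like the following; r, lora_alpha, lora_dropout and target_modules here are illustrative assumptions, not the commit's actual values:

# Illustrative values only; the actual LoraConfig of this commit is not shown
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # sanity check: only the LoRA weights should be trainable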
# 7⃣ Tokenize the data
max_length = 128
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
tokenized_dataset = dataset.map(tokenize_function, batched=True)
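A quick sanity check after the map call confirms that the new max_length actually takes effect; a small sketch assuming the updated tokenize_function above:

sample = tokenized_dataset[0]
print(sample.keys())             # at least input_ids and attention_mask, plus the original text/embedding columns
print(len(sample["input_ids"]))  # 128, because padding="max_length" with max_length=128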
@ -56,22 +65,36 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True)
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    gradient_accumulation_steps=4,  # Simulates a larger batch size
    num_train_epochs=5,
    logging_dir="./logs",
    save_strategy="epoch"
    save_strategy="epoch",
    learning_rate=2e-5,
    warmup_steps=100,
    fp16=True,  # Use mixed precision training
    evaluation_strategy="steps",
    eval_steps=500,
    load_best_model_at_end=True,
)
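One caveat with the new arguments: in transformers, load_best_model_at_end=True requires the evaluation and save strategies to match, and evaluation_strategy="steps" only works if an eval_dataset is passed to the Trainer; as written (save per epoch, eval per steps, no eval set) the Trainer will raise an error before training starts. A consistent variant could look like this (the eval split is an assumption, not part of the commit):

# Assumed: split off a small eval set so evaluation and load_best_model_at_end can work
split = tokenized_dataset.train_test_split(test_size=0.1)

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=5,
    logging_dir="./logs",
    learning_rate=2e-5,
    warmup_steps=100,
    fp16=True,                    # requires a CUDA GPU
    evaluation_strategy="epoch",  # match save_strategy so load_best_model_at_end is allowed
    save_strategy="epoch",
    load_best_model_at_end=True,
)

The Trainer below would then take train_dataset=split["train"] and eval_dataset=split["test"].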
# 9⃣ Train the model
# 9⃣ Data Collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
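With mlm=False this collator prepares batches for causal language modeling: it pads the tokenized examples and copies input_ids into labels, setting pad positions to -100 so they are ignored by the loss. A small check, assuming the tokenization above (columns are selected manually here; in the actual run the Trainer drops the unused text/embedding columns itself before collation):

features = [
    {k: tokenized_dataset[i][k] for k in ("input_ids", "attention_mask")}
    for i in range(2)
]
batch = data_collator(features)
print(batch.keys())            # input_ids, attention_mask, labels
print(batch["labels"].shape)   # labels have the same shape as input_ids; pad positions are -100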
# 🔟 Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
trainer.train()
# 🔟 Save the fine-tuned model
# 1⃣1⃣ Save the fine-tuned model
model.save_pretrained("./trained_model/gemma")
tokenizer.save_pretrained("./trained_model/gemma")
print("✅ Model trained and saved!")
print("✅ Model trained and saved!")
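For completeness: the saved ./trained_model/gemma directory holds the LoRA adapter weights and the tokenizer, not a full merged model. A hedged sketch of loading it back for inference; the base checkpoint name is an assumption, since the diff does not show which Gemma model was loaded:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "google/gemma-2b"  # assumption: whichever checkpoint the script fine-tuned
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(base_model, "./trained_model/gemma")  # attach the LoRA adapter
tokenizer = AutoTokenizer.from_pretrained("./trained_model/gemma")

inputs = tokenizer("Example prompt", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))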