mod gemma

l.gabrysiak 2025-02-26 13:15:18 +01:00
parent 30a6350071
commit b822c32206
1 changed file with 34 additions and 11 deletions


@@ -1,10 +1,13 @@
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 import torch
 import faiss
 import numpy as np
 from sentence_transformers import SentenceTransformer
 from datasets import Dataset
 from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling

 # 1⃣ Initialize the embedding model
 embed_model = SentenceTransformer("all-MiniLM-L6-v2")
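For orientation, all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, which is the dim the FAISS index in the next hunk is built with; a quick check, with an illustrative sample sentence:

sample = embed_model.encode(["a short test sentence"])
print(sample.shape)  # (1, 384): one 384-dimensional vector per input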
@@ -19,13 +22,12 @@ documents = [
 embeddings = embed_model.encode(documents)

 # 3⃣ Initialize FAISS and add the vectors
-dim = embeddings.shape[1]  # vector dimensionality
-index = faiss.IndexFlatL2(dim)  # create a FAISS index with the L2 metric
-index.add(np.array(embeddings, dtype=np.float32))  # add the vectors to the FAISS index
+dim = embeddings.shape[1]
+index = faiss.IndexFlatL2(dim)
+index.add(np.array(embeddings, dtype=np.float32))

 # 4⃣ Prepare the training data
 def create_training_data():
-    # fetch the documents (they can be paired with their embeddings if needed)
     data = {
         "text": documents,
         "embedding": embeddings.tolist()
@@ -47,8 +49,15 @@ lora_config = LoraConfig(
 model = get_peft_model(model, lora_config)

 # 7⃣ Tokenize the data
+max_length = 128
+
 def tokenize_function(examples):
-    return tokenizer(examples["text"], padding="max_length", truncation=True)
+    return tokenizer(
+        examples["text"],
+        padding="max_length",
+        truncation=True,
+        max_length=max_length
+    )

 tokenized_dataset = dataset.map(tokenize_function, batched=True)
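A quick sanity check of the mapped dataset (assuming the tokenizer loaded earlier in the script): each example should now carry fixed-length token ids next to the original columns.

example = tokenized_dataset[0]
print(len(example["input_ids"]))  # 128, enforced by padding/truncation to max_length
print(example.keys())             # includes "text", "embedding", "input_ids", "attention_mask"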
@@ -56,21 +65,35 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True)
 training_args = TrainingArguments(
     output_dir="./results",
     per_device_train_batch_size=2,
-    num_train_epochs=3,
+    gradient_accumulation_steps=4,  # simulates a larger effective batch size
+    num_train_epochs=5,
     logging_dir="./logs",
-    save_strategy="epoch"
+    save_strategy="epoch",
+    learning_rate=2e-5,
+    warmup_steps=100,
+    fp16=True,  # mixed-precision training (needs a GPU)
+    evaluation_strategy="epoch",  # must match save_strategy when load_best_model_at_end is set
+    load_best_model_at_end=True,
 )

-# 9⃣ Train the model
+# 9⃣ Data collator (mlm=False keeps the causal-LM objective: labels are copied from the input ids)
+data_collator = DataCollatorForLanguageModeling(
+    tokenizer=tokenizer,
+    mlm=False
+)
+
+# 🔟 Train the model
 trainer = Trainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_dataset,
+    eval_dataset=tokenized_dataset,  # no separate eval split in this script, so the train set doubles as eval data
+    data_collator=data_collator,
 )

 trainer.train()

-# 🔟 Save the fine-tuned model
+# 1⃣1⃣ Save the fine-tuned model
 model.save_pretrained("./trained_model/gemma")
 tokenizer.save_pretrained("./trained_model/gemma")
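Not part of this commit, but a sketch of how the saved LoRA adapter could be loaded for inference; the base checkpoint name "google/gemma-2b" and the prompt are assumptions, since the model-loading lines sit outside this diff:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("google/gemma-2b")    # assumed base model
model = PeftModel.from_pretrained(base, "./trained_model/gemma")  # attach the saved LoRA adapter
tokenizer = AutoTokenizer.from_pretrained("./trained_model/gemma")

inputs = tokenizer("Example prompt", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))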