ably.do/gemma.py

102 lines
3.0 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
# 1⃣ Inicjalizacja modelu do embeddingów
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# 2⃣ Dodanie dokumentów i embeddingów
documents = [
"Jak założyć firmę w Polsce?",
"Jak rozliczyć podatek VAT?",
"Procedura składania reklamacji w e-sklepie.",
"Jakie dokumenty są potrzebne do rejestracji działalności?"
]
embeddings = embed_model.encode(documents)
# 3⃣ Inicjalizacja FAISS i dodanie wektorów
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(np.array(embeddings, dtype=np.float32))
# 4⃣ Przygotowanie danych treningowych
def create_training_data():
data = {
"text": documents,
"embedding": embeddings.tolist()
}
return Dataset.from_dict(data)
dataset = create_training_data()
# 5⃣ Ładowanie modelu Gemma 2 7B
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "google/gemma-2-2b"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 6⃣ Konfiguracja LoRA dla efektywnego treningu
lora_config = LoraConfig(
r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# 7⃣ Tokenizacja danych
max_length = 128
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=max_length
)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# 8⃣ Parametry treningu
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="steps", # Zmienione na "steps"
eval_steps=500, # Dodane
save_strategy="steps", # Zmienione na "steps"
save_steps=500, # Dodane, musi być takie samo jak eval_steps lub jego wielokrotność
learning_rate=2e-5,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
num_train_epochs=5,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="loss", # lub inna metryka, którą chcesz optymalizować
greater_is_better=False, # Ustaw na True, jeśli wyższa wartość metryki jest lepsza
)
# 9⃣ Data Collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# 🔟 Trening modelu
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=data_collator,
)
trainer.train()
# 1⃣1⃣ Zapisanie dostrojonego modelu
model.save_pretrained("./trained_model/gemma")
tokenizer.save_pretrained("./trained_model/gemma")
print("✅ Model został wytrenowany i zapisany!")