ably.do/gemma.py


import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
# 1. Initialize the embedding model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# 2. Define the documents and compute their embeddings
documents = [
    "Jak założyć firmę w Polsce?",
    "Jak rozliczyć podatek VAT?",
    "Procedura składania reklamacji w e-sklepie.",
    "Jakie dokumenty są potrzebne do rejestracji działalności?"
]
embeddings = embed_model.encode(documents)
# 3. Initialize a FAISS index and add the vectors
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(np.array(embeddings, dtype=np.float32))
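# The index built above is never queried later in this script; as a minimal
# retrieval sketch (the query string is just an illustrative example), the
# nearest documents could be looked up like this:
query_embedding = embed_model.encode(["Jak zarejestrować działalność gospodarczą?"])
distances, indices = index.search(np.array(query_embedding, dtype=np.float32), 2)
print("Nearest documents:", [documents[i] for i in indices[0]])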
# 4. Prepare the training data
def create_training_data():
    data = {
        "text": documents,
        "embedding": embeddings.tolist()
    }
    return Dataset.from_dict(data)
dataset = create_training_data()
# 5. Load the Gemma 2B model
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 6. Configure LoRA for parameter-efficient fine-tuning
lora_config = LoraConfig(
    r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
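# Optional sanity check: report how many parameters the LoRA adapter actually
# trains (uses the standard peft PeftModel API, nothing beyond what is imported above).
model.print_trainable_parameters()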
# 7. Tokenize the dataset
max_length = 128
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# 8. Training arguments (no eval split is defined, so evaluation is disabled)
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # simulates a larger effective batch size
    num_train_epochs=5,
    logging_dir="./logs",
    save_strategy="epoch",
    learning_rate=2e-5,
    warmup_steps=100,
    fp16=True,  # mixed-precision training
)
# 9. Data collator (causal LM objective, no masked-LM)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
# 10. Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
trainer.train()
# 11. Save the fine-tuned model (LoRA adapter) and tokenizer
model.save_pretrained("./trained_model/gemma")
tokenizer.save_pretrained("./trained_model/gemma")
print("✅ Model został wytrenowany i zapisany!")