back to gemma2

l.gabrysiak 2025-02-25 15:20:55 +01:00
parent 7179a2de95
commit eb1f2229f0
1 changed file with 56 additions and 53 deletions

hft.py

@@ -1,7 +1,7 @@
import os
import torch
import torch.nn as nn
from transformers import GPTNeoForCausalLM, Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import Dataset
from PIL import Image
import re
@@ -9,36 +9,28 @@ import pytesseract
import docx2txt
import PyPDF2
import json
from torch.cuda.amp import autocast
from collections import defaultdict
from huggingface_hub import login
import torch
torch.cuda.empty_cache()
# Log in to the Hugging Face Hub
login(token="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")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
def free_memory():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# New class for managing sources
class SourceMapper:
def __init__(self):
self.source_to_idx = defaultdict(lambda: 0) # Default to 0 for unknown sources
self.idx_to_source = {0: "Unknown"}
self.next_idx = 1 # Indices start at 1 for known sources
self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
self.idx_to_source = {}
def add_source(self, source):
if source and source not in self.source_to_idx:
idx = self.next_idx
self.source_to_idx[source] = idx
idx = self.source_to_idx[source]
self.idx_to_source[idx] = source
self.next_idx += 1
def get_idx(self, source):
return self.source_to_idx.get(source, 0)
return self.source_to_idx[source] if source else -1
def get_source(self, idx):
return self.idx_to_source.get(idx, "Unknown")
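For reference, a minimal sketch (not part of the commit) of the indexing behaviour behind the new defaultdict-based SourceMapper: the factory lambda returns the current size of the dict, so every previously unseen source receives the next consecutive index, starting from 0 rather than the 1 reserved in the earlier version. The file names below are made up.

from collections import defaultdict

source_to_idx = defaultdict(lambda: len(source_to_idx))
for src in ["kodeks_cywilny.pdf", "kodeks_karny.pdf", "kodeks_cywilny.pdf"]:
    _ = source_to_idx[src]   # first lookup inserts the key with the next free index
print(dict(source_to_idx))   # {'kodeks_cywilny.pdf': 0, 'kodeks_karny.pdf': 1}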
@@ -62,7 +54,7 @@ def extract_text_from_file(file_path):
with open(file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
for page in reader.pages:
text += page.extract_text() or ""
text += page.extract_text()
return text
elif ext in ['.doc', '.docx']:
return docx2txt.process(file_path)
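As an aside, a hedged sketch of the PDF branch in isolation: PyPDF2's page.extract_text() can return None for pages without a text layer, which is what the or "" guard in the previous version coerced to an empty string.

import PyPDF2

def pdf_to_text(file_path):
    # Illustrative sketch only, mirroring the PDF branch of extract_text_from_file.
    text = ""
    with open(file_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        for page in reader.pages:
            # extract_text() may return None on image-only pages,
            # so coerce it to an empty string before concatenating.
            text += page.extract_text() or ""
    return text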
@@ -84,7 +76,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
doc_type = identify_legal_document(file, file_catalog)
if doc_type != "Opracowanie własne":
articles = re.split(r'(Art\.\s+\d+\.)', text)
articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
for i in range(1, len(articles), 2):
article_number = articles[i].strip()
article_content = articles[i+1].strip() if i+1 < len(articles) else ""
@@ -100,7 +92,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
for chunk in chunks:
data.append({
"text": chunk,
"source_idx": 0
"source_idx": -1 # Brak źródła
})
return data
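To make the split visible, a small illustration with a made-up sample string: because the pattern is wrapped in a capturing group, re.split returns the matched article headers interleaved with the article bodies, which is why the loop walks the result in steps of two starting at index 1.

import re

sample = "Preambuła. Art. 1. Pierwszy przepis. Art. 2. Drugi przepis."
parts = re.split(r'(Art\.\s+\d+[\.\s])', sample)
# parts: ['Preambuła. ', 'Art. 1.', ' Pierwszy przepis. ', 'Art. 2.', ' Drugi przepis.']
for i in range(1, len(parts), 2):
    number = parts[i].strip()
    content = parts[i + 1].strip() if i + 1 < len(parts) else ""
    print(number, "->", content)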
@@ -120,74 +112,85 @@ def custom_collate_fn(batch):
input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
# Add a default source_idx if one is missing
source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
class CustomModel(GPTNeoForCausalLM):
class CustomModel(AutoModelForCausalLM):
def __init__(self, config):
super().__init__(config)
self.source_embedding = nn.Embedding(
num_embeddings=1000,
num_embeddings=1000, # Maximum number of unique sources
embedding_dim=config.hidden_size,
padding_idx=0 # Corrected padding_idx
padding_idx=-1
)
self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
with autocast():
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
**kwargs
)
if source_idx is not None:
source_embeds = self.source_embedding(source_idx)
source_projected = self.source_proj(source_embeds)
outputs.logits += source_projected.unsqueeze(1)
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
**kwargs
)
if source_idx is not None:
# Add the source embedding to the logits
source_embeds = self.source_embedding(source_idx).unsqueeze(1)
outputs.logits += source_embeds
return outputs
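For shape bookkeeping, a self-contained sketch of the per-source logit bias idea with toy dimensions: the source embedding lives in hidden_size space, so it has to be mapped to vocab_size (as the source_proj layer in the earlier version does) before it can be broadcast onto logits of shape (batch, seq_len, vocab_size).

import torch
import torch.nn as nn

batch, seq_len, hidden_size, vocab_size, num_sources = 2, 5, 16, 32, 10

source_embedding = nn.Embedding(num_sources, hidden_size)
source_proj = nn.Linear(hidden_size, vocab_size)

logits = torch.randn(batch, seq_len, vocab_size)    # stand-in for the base model's output
source_idx = torch.tensor([3, 7])                   # one source id per example
bias = source_proj(source_embedding(source_idx))    # (batch, vocab_size)
logits = logits + bias.unsqueeze(1)                 # broadcasts over seq_len
print(logits.shape)                                 # torch.Size([2, 5, 32])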
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
labels = inputs.pop("labels")
source_idx = inputs.pop("source_idx")
outputs = model(**inputs, labels=labels, source_idx=source_idx)
return (outputs.loss, outputs) if return_outputs else outputs.loss
# Initialize components
source_mapper = SourceMapper()
model_name = "EleutherAI/gpt-neo-1.3B"
model_name = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
data = prepare_dataset("files", "file_catalog.json", source_mapper)
# Prepare the data
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path, source_mapper)
dataset = Dataset.from_list(data)
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
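tokenize_function itself is defined outside the changed hunks; the sketch below is only an assumption about what it has to provide for the collator above to work, namely fixed-length padding and labels that mirror input_ids.

def tokenize_function(examples):
    # Hypothetical sketch; the real tokenize_function is not shown in this diff.
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # Causal LM fine-tuning commonly uses the inputs themselves as labels.
    tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
    return tokenized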
# Initialize the model
config = AutoModelForCausalLM.from_pretrained(model_name).config
model = CustomModel.from_pretrained(model_name)
model.config.gradient_checkpointing = True
model.config.use_cache = False
model.resize_token_embeddings(len(tokenizer))
model = CustomModel.from_pretrained(model_name, config=config)
model.gradient_checkpointing_enable()
# Training configuration
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
gradient_accumulation_steps=8,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=2e-5,
fp16=True,
logging_steps=50,
logging_steps=100,
save_strategy="steps",
save_steps=500,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
logging_dir='./logs'
save_steps=1000,
report_to="none",
gradient_checkpointing=True
)
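A quick sanity check on the new argument values: with per_device_train_batch_size=2 and gradient_accumulation_steps=4, each optimizer update still sees 8 sequences per device, just spread over four forward/backward passes.

per_device_train_batch_size = 2
gradient_accumulation_steps = 4
effective_batch_per_device = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_per_device)  # 8 sequences contribute to every optimizer step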
trainer = Trainer(
# Training
trainer = CustomTrainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=custom_collate_fn
data_collator=custom_collate_fn, # Use the custom collate_fn
batch_size=8 # reduced batch size
)
trainer.train()
free_memory()
# Function that generates an answer
def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
@@ -203,8 +206,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
# Get the source from the last token
last_token_logits = outputs.scores[-1]
source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
last_token_id = outputs.sequences[0][-1].item()
source_idx = model.source_embedding.weight.shape[0] - 1 # Temporary workaround
source = source_mapper.get_source(source_idx)
return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
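Finally, a hedged usage sketch for the end of the script; the question text and the eval/no_grad wrapping are assumptions, not part of the commit.

if __name__ == "__main__":
    # Illustrative call only; prompt and inference settings are assumptions.
    question = "Co reguluje Art. 415 Kodeksu cywilnego?"
    model.eval()
    with torch.no_grad():
        print(generate_answer(question, model, tokenizer, source_mapper))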