Back to gemma2

l.gabrysiak 2025-02-25 15:20:55 +01:00
parent 7179a2de95
commit eb1f2229f0
1 changed file with 56 additions and 53 deletions

hft.py

@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import GPTNeoForCausalLM, Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
 from PIL import Image
 import re
@@ -9,36 +9,28 @@ import pytesseract
 import docx2txt
 import PyPDF2
 import json
-from torch.cuda.amp import autocast
 from collections import defaultdict
 from huggingface_hub import login
-import torch
 torch.cuda.empty_cache()
+# Log in to the Hugging Face Hub
 login(token="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")
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-def free_memory():
-    torch.cuda.empty_cache()
-    torch.cuda.ipc_collect()
+# New class for managing sources
 class SourceMapper:
     def __init__(self):
-        self.source_to_idx = defaultdict(lambda: 0)  # Default 0 for unknown sources
-        self.idx_to_source = {0: "Unknown"}
-        self.next_idx = 1  # Indices from 1 for known sources
+        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
+        self.idx_to_source = {}
     def add_source(self, source):
         if source and source not in self.source_to_idx:
-            idx = self.next_idx
-            self.source_to_idx[source] = idx
+            idx = self.source_to_idx[source]
             self.idx_to_source[idx] = source
-            self.next_idx += 1
     def get_idx(self, source):
-        return self.source_to_idx.get(source, 0)
+        return self.source_to_idx[source] if source else -1
     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
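
A minimal usage sketch of the reworked SourceMapper (the + lines above); the file names are hypothetical and only illustrate how the defaultdict factory hands out consecutive indices:

mapper = SourceMapper()
mapper.add_source("kodeks_cywilny.pdf")        # hypothetical source -> index 0
mapper.add_source("kodeks_karny.pdf")          # hypothetical source -> index 1
print(mapper.get_idx("kodeks_cywilny.pdf"))    # 0
print(mapper.get_source(1))                    # kodeks_karny.pdf
print(mapper.get_idx(None))                    # -1 for a missing source

Note that get_idx indexes the defaultdict directly, so querying a source that was never added also allocates a fresh index as a side effect.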
@@ -62,7 +54,7 @@ def extract_text_from_file(file_path):
         with open(file_path, 'rb') as file:
             reader = PyPDF2.PdfReader(file)
             for page in reader.pages:
-                text += page.extract_text() or ""
+                text += page.extract_text()
         return text
     elif ext in ['.doc', '.docx']:
         return docx2txt.process(file_path)
@@ -84,7 +76,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
             doc_type = identify_legal_document(file, file_catalog)
             if doc_type != "Opracowanie własne":
-                articles = re.split(r'(Art\.\s+\d+\.)', text)
+                articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
                 for i in range(1, len(articles), 2):
                     article_number = articles[i].strip()
                     article_content = articles[i+1].strip() if i+1 < len(articles) else ""
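
A short sketch of what the widened pattern on the + side produces: re.split with a capture group keeps the "Art. N" headers in the result, so headers land at odd indices and their bodies at the following even indices, which is what the range(1, len(articles), 2) loop walks. The sample text is made up:

import re

sample = "Art. 1. Pierwszy przepis. Art. 2 Drugi przepis."   # made-up fragment
articles = re.split(r'(Art\.\s+\d+[\.\s])', sample)
for i in range(1, len(articles), 2):
    number = articles[i].strip()
    content = articles[i + 1].strip() if i + 1 < len(articles) else ""
    print(number, "->", content)   # "Art. 1." -> "Pierwszy przepis.", then "Art. 2" -> "Drugi przepis."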
@@ -100,7 +92,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
             for chunk in chunks:
                 data.append({
                     "text": chunk,
-                    "source_idx": 0
+                    "source_idx": -1  # No source
                 })
     return data
@@ -120,74 +112,85 @@ def custom_collate_fn(batch):
     input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
     attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
     labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
+    # Add a default source_idx if it is missing
     source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}

-class CustomModel(GPTNeoForCausalLM):
+class CustomModel(AutoModelForCausalLM):
     def __init__(self, config):
         super().__init__(config)
         self.source_embedding = nn.Embedding(
-            num_embeddings=1000,
+            num_embeddings=1000,  # Maximum number of unique sources
             embedding_dim=config.hidden_size,
-            padding_idx=0  # Corrected padding_idx
+            padding_idx=-1
         )
-        self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)

     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        with autocast():
-            outputs = super().forward(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                labels=labels,
-                **kwargs
-            )
+        outputs = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+            **kwargs
+        )
         if source_idx is not None:
-            source_embeds = self.source_embedding(source_idx)
-            source_projected = self.source_proj(source_embeds)
-            outputs.logits += source_projected.unsqueeze(1)
+            # Add the source embedding to the hidden states
+            source_embeds = self.source_embedding(source_idx).unsqueeze(1)
+            outputs.logits += source_embeds
         return outputs

+class CustomTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        labels = inputs.pop("labels")
+        source_idx = inputs.pop("source_idx")
+        outputs = model(**inputs, labels=labels, source_idx=source_idx)
+        return (outputs.loss, outputs) if return_outputs else outputs.loss

+# Initialize the components
 source_mapper = SourceMapper()
-model_name = "EleutherAI/gpt-neo-1.3B"
+model_name = "google/gemma-2-2b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token

-data = prepare_dataset("files", "file_catalog.json", source_mapper)
+# Prepare the data
+catalog_path = "file_catalog.json"
+data = prepare_dataset("files", catalog_path, source_mapper)
 dataset = Dataset.from_list(data)
-tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
+tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)

+# Initialize the model
 config = AutoModelForCausalLM.from_pretrained(model_name).config
-model = CustomModel.from_pretrained(model_name)
-model.config.gradient_checkpointing = True
-model.config.use_cache = False
-model.resize_token_embeddings(len(tokenizer))
+model = CustomModel.from_pretrained(model_name, config=config)
 model.gradient_checkpointing_enable()

+# Training configuration
 training_args = TrainingArguments(
     output_dir="./results",
     num_train_epochs=3,
-    gradient_accumulation_steps=8,
+    per_device_train_batch_size=2,
+    gradient_accumulation_steps=4,
     learning_rate=2e-5,
     fp16=True,
-    logging_steps=50,
+    logging_steps=100,
     save_strategy="steps",
-    save_steps=500,
-    per_device_train_batch_size=2,
-    per_device_eval_batch_size=2,
-    logging_dir='./logs'
+    save_steps=1000,
+    report_to="none",
+    gradient_checkpointing=True
 )

-trainer = Trainer(
+# Training
+trainer = CustomTrainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_dataset,
-    data_collator=custom_collate_fn
+    data_collator=custom_collate_fn,  # Use the custom collate_fn
+    batch_size=8  # Reduce the batch size
 )
 trainer.train()
-free_memory()

 # Function that generates the answer
 def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
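
A condensed sketch of the batching pattern the + side relies on: custom_collate_fn carries an extra source_idx tensor through every batch (falling back to -1 when a record has no source), and CustomTrainer.compute_loss pops labels and source_idx back out of the inputs before calling the model, so the remaining keys match the base model's forward signature. The toy values below are placeholders, assuming the custom_collate_fn definition shown above:

import torch

batch = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3], "source_idx": 0},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 1], "labels": [4, 5, 6]},   # record without a source
]
collated = custom_collate_fn(batch)
print(collated["input_ids"].shape)   # torch.Size([2, 3])
print(collated["source_idx"])        # tensor([ 0, -1])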
@@ -203,8 +206,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
     # Get the source from the last token
-    last_token_logits = outputs.scores[-1]
-    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
+    last_token_id = outputs.sequences[0][-1].item()
+    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
     source = source_mapper.get_source(source_idx)
     return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"