diff --git a/hft.py b/hft.py
index ecc1e42..eccec6c 100644
--- a/hft.py
+++ b/hft.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import GPTNeoForCausalLM, Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
 from PIL import Image
 import re
@@ -9,36 +9,28 @@
 import pytesseract
 import docx2txt
 import PyPDF2
 import json
-from torch.cuda.amp import autocast
 from collections import defaultdict
 from huggingface_hub import login
+import torch

 torch.cuda.empty_cache()
-# Log in to the Hugging Face Hub
 login(token="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")
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
-def free_memory():
-    torch.cuda.empty_cache()
-    torch.cuda.ipc_collect()

+# New class for managing sources
 class SourceMapper:
     def __init__(self):
-        self.source_to_idx = defaultdict(lambda: 0)  # Default 0 for unknown sources
-        self.idx_to_source = {0: "Unknown"}
-        self.next_idx = 1  # Indices from 1 for known sources
+        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
+        self.idx_to_source = {}

     def add_source(self, source):
         if source and source not in self.source_to_idx:
-            idx = self.next_idx
-            self.source_to_idx[source] = idx
+            idx = self.source_to_idx[source]
             self.idx_to_source[idx] = source
-            self.next_idx += 1

     def get_idx(self, source):
-        return self.source_to_idx.get(source, 0)
+        return self.source_to_idx[source] if source else -1

     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
@@ -62,7 +54,7 @@ def extract_text_from_file(file_path):
         with open(file_path, 'rb') as file:
             reader = PyPDF2.PdfReader(file)
             for page in reader.pages:
-                text += page.extract_text() or ""
+                text += page.extract_text()
         return text
     elif ext in ['.doc', '.docx']:
         return docx2txt.process(file_path)
@@ -84,7 +76,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
             doc_type = identify_legal_document(file, file_catalog)

             if doc_type != "Opracowanie własne":
-                articles = re.split(r'(Art\.\s+\d+\.)', text)
+                articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
                 for i in range(1, len(articles), 2):
                     article_number = articles[i].strip()
                     article_content = articles[i+1].strip() if i+1 < len(articles) else ""
@@ -100,7 +92,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
                 for chunk in chunks:
                     data.append({
                         "text": chunk,
-                        "source_idx": 0
+                        "source_idx": -1  # No source
                     })

     return data
@@ -120,74 +112,85 @@ def custom_collate_fn(batch):
     input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
     attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
     labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
+
+    # Add a default source_idx if one is missing
     source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
+
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}

-class CustomModel(GPTNeoForCausalLM):
+class CustomModel(AutoModelForCausalLM):
     def __init__(self, config):
         super().__init__(config)
         self.source_embedding = nn.Embedding(
-            num_embeddings=1000,
+            num_embeddings=1000,  # Maximum number of unique sources
             embedding_dim=config.hidden_size,
-            padding_idx=0  # Fixed padding_idx
+            padding_idx=-1
         )
-        self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)

     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        with autocast():
-            outputs = super().forward(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                labels=labels,
-                **kwargs
-            )
-            if source_idx is not None:
-                source_embeds = self.source_embedding(source_idx)
-                source_projected = self.source_proj(source_embeds)
-                outputs.logits += source_projected.unsqueeze(1)
+        outputs = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+            **kwargs
+        )
+
+        if source_idx is not None:
+            # Add the source embedding to the output logits
+            source_embeds = self.source_embedding(source_idx).unsqueeze(1)
+            outputs.logits += source_embeds
+
         return outputs

+class CustomTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        labels = inputs.pop("labels")
+        source_idx = inputs.pop("source_idx")
+        outputs = model(**inputs, labels=labels, source_idx=source_idx)
+        return (outputs.loss, outputs) if return_outputs else outputs.loss
+
+# Initialize components
 source_mapper = SourceMapper()
-model_name = "EleutherAI/gpt-neo-1.3B"
+model_name = "google/gemma-2-2b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token

-data = prepare_dataset("files", "file_catalog.json", source_mapper)
+# Prepare the data
+catalog_path = "file_catalog.json"
+data = prepare_dataset("files", catalog_path, source_mapper)
 dataset = Dataset.from_list(data)
-tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
+tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)

+# Initialize the model
 config = AutoModelForCausalLM.from_pretrained(model_name).config
-model = CustomModel.from_pretrained(model_name)
-model.config.gradient_checkpointing = True
-model.config.use_cache = False
-model.resize_token_embeddings(len(tokenizer))
+model = CustomModel.from_pretrained(model_name, config=config)
 model.gradient_checkpointing_enable()

+# Training configuration
 training_args = TrainingArguments(
     output_dir="./results",
     num_train_epochs=3,
-    gradient_accumulation_steps=8,
+    per_device_train_batch_size=2,
+    gradient_accumulation_steps=4,
     learning_rate=2e-5,
     fp16=True,
-    logging_steps=50,
+    logging_steps=100,
     save_strategy="steps",
-    save_steps=500,
-    per_device_train_batch_size=2,
-    per_device_eval_batch_size=2,
-    logging_dir='./logs'
+    save_steps=1000,
+    report_to="none",
+    gradient_checkpointing=True
 )

-trainer = Trainer(
+# Training
+trainer = CustomTrainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_dataset,
-    data_collator=custom_collate_fn
+    data_collator=custom_collate_fn,  # Use the custom collate_fn
+    batch_size=8  # reduce the batch size
 )
-
 trainer.train()

-free_memory()
-
 # Function that generates an answer
 def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
@@ -203,8 +206,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

     # Get the source from the last token
-    last_token_logits = outputs.scores[-1]
-    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
+    last_token_id = outputs.sequences[0][-1].item()
+    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
     source = source_mapper.get_source(source_idx)

     return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"