import os
import re
import json

import torch
import torch.nn as nn
from transformers import (
    GPTNeoForCausalLM,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
)
from datasets import Dataset
from PIL import Image
import pytesseract
import docx2txt
import PyPDF2
from huggingface_hub import login

torch.cuda.empty_cache()

# Log in to the Hugging Face Hub; the token is read from the environment
# rather than hard-coded in source.
login(token=os.environ["HF_TOKEN"])

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


def free_memory():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


class SourceMapper:
    """Maps source descriptions (e.g. "Kodeks pracy, Art. 152.") to integer ids."""

    def __init__(self):
        self.source_to_idx = {}
        self.idx_to_source = {0: "Unknown"}  # 0 is reserved for unknown sources
        self.next_idx = 1  # indices from 1 for known sources

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.next_idx
            self.source_to_idx[source] = idx
            self.idx_to_source[idx] = source
            self.next_idx += 1

    def get_idx(self, source):
        return self.source_to_idx.get(source, 0)

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")


def load_file_catalog(catalog_path):
    with open(catalog_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def identify_legal_document(filename, file_catalog):
    # "Opracowanie własne" ("own compilation") marks uncatalogued material.
    return file_catalog.get(filename, "Opracowanie własne")


def extract_text_from_file(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext in ['.txt', '.md']:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    elif ext == '.pdf':
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            for page in reader.pages:
                text += page.extract_text() or ""
        return text
    elif ext in ['.doc', '.docx']:
        return docx2txt.process(file_path)
    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
        return pytesseract.image_to_string(Image.open(file_path))
    else:
        return ""


def prepare_dataset(directory, catalog_path, source_mapper):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            text = extract_text_from_file(file_path)
            if not text:
                continue
            doc_type = identify_legal_document(file, file_catalog)
            if doc_type != "Opracowanie własne":
                # Split catalogued statutes on article headers ("Art. 152.")
                # so each training example carries an article-level source.
                articles = re.split(r'(Art\.\s+\d+\.)', text)
                for i in range(1, len(articles), 2):
                    article_number = articles[i].strip()
                    article_content = articles[i + 1].strip() if i + 1 < len(articles) else ""
                    source = f"{doc_type}, {article_number}"
                    source_mapper.add_source(source)
                    data.append({
                        "text": f"{article_number} {article_content}",
                        "source_idx": source_mapper.get_idx(source)
                    })
            else:
                # Uncatalogued material is chunked and tagged as source 0.
                chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
                for chunk in chunks:
                    data.append({"text": chunk, "source_idx": 0})
    return data


def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )
    tokenized["labels"] = tokenized["input_ids"].clone()
    # Mask padding positions so they do not contribute to the LM loss.
    tokenized["labels"][tokenized["attention_mask"] == 0] = -100
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized


def custom_collate_fn(batch):
    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
    labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
    # Default to 0 (unknown source); a negative index would crash the
    # embedding lookup in CustomModel.
    source_idx = torch.tensor([b.get("source_idx", 0) for b in batch], dtype=torch.long)
    return {"input_ids": input_ids, "attention_mask": attention_mask,
            "labels": labels, "source_idx": source_idx}
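# For reference, a minimal sketch of the catalog this script expects:
# a flat JSON object mapping file names to document titles. The entries
# below are illustrative assumptions, not taken from a real catalog.
#
# file_catalog.json:
# {
#     "kodeks_pracy.pdf": "Kodeks pracy",
#     "notatki_wewnetrzne.txt": "Opracowanie własne"
# }
#
# With such a catalog, prepare_dataset() splits "kodeks_pracy.pdf" into
# per-article examples tagged e.g. "Kodeks pracy, Art. 152.", while the
# uncatalogued file falls through to plain 512-character chunks.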
labels, "source_idx": source_idx} class CustomModel(GPTNeoForCausalLM): def __init__(self, config): super().__init__(config) self.source_embedding = nn.Embedding( num_embeddings=1000, embedding_dim=config.hidden_size, padding_idx=0 # Poprawiony padding_idx ) self.source_proj = nn.Linear(config.hidden_size, config.vocab_size) def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs): with autocast(): outputs = super().forward( input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs ) if source_idx is not None: source_embeds = self.source_embedding(source_idx) source_projected = self.source_proj(source_embeds) outputs.logits += source_projected.unsqueeze(1) return outputs source_mapper = SourceMapper() model_name = "EleutherAI/gpt-neo-1.3B" tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token data = prepare_dataset("files", "file_catalog.json", source_mapper) dataset = Dataset.from_list(data) tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16) config = AutoModelForCausalLM.from_pretrained(model_name).config model = CustomModel.from_pretrained(model_name) model.config.gradient_checkpointing = True model.config.use_cache = False model.resize_token_embeddings(len(tokenizer)) model.gradient_checkpointing_enable() training_args = TrainingArguments( output_dir="./results", num_train_epochs=3, gradient_accumulation_steps=8, learning_rate=2e-5, fp16=True, logging_steps=50, save_strategy="steps", save_steps=500, per_device_train_batch_size=2, per_device_eval_batch_size=2, logging_dir='./logs' ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=custom_collate_fn ) trainer.train() free_memory() # Funkcja generująca odpowiedź def generate_answer(question, model, tokenizer, source_mapper, max_length=200): inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512) outputs = model.generate( **inputs, max_length=max_length, num_return_sequences=1, return_dict_in_generate=True, output_scores=True, ) answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) # Pobierz źródło z ostatniego tokena last_token_logits = outputs.scores[-1] source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item() source = source_mapper.get_source(source_idx) return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}" # Przykład użycia question = "Ile dni urlopu przysługuje pracownikowi?" answer = generate_answer(question, model, tokenizer, source_mapper) print(answer)