import os
import re
import json
from collections import defaultdict

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, GPTNeoForCausalLM, Trainer, TrainingArguments
from datasets import Dataset
from PIL import Image
import pytesseract
import docx2txt
import PyPDF2
from huggingface_hub import login

if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Read the Hub token from the environment; never hard-code credentials in source.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Class for managing sources: maps source strings
# (e.g. "Kodeks pracy, Art. 152") to stable integer indices and back.
class SourceMapper:
    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")


def load_file_catalog(catalog_path):
    with open(catalog_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def identify_legal_document(filename, file_catalog):
    # Uncatalogued files fall back to "Opracowanie własne" ("own elaboration").
    return file_catalog.get(filename, "Opracowanie własne")


def extract_text_from_file(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext in ['.txt', '.md']:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    elif ext == '.pdf':
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            for page in reader.pages:
                # extract_text() may return None for image-only pages.
                text += page.extract_text() or ""
        return text
    elif ext in ['.doc', '.docx']:
        return docx2txt.process(file_path)
    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
        return pytesseract.image_to_string(Image.open(file_path))
    else:
        return ""


def prepare_dataset(directory, catalog_path, source_mapper):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            text = extract_text_from_file(file_path)
            if not text:
                continue
            doc_type = identify_legal_document(file, file_catalog)
            if doc_type != "Opracowanie własne":
                # Split on article headers ("Art. 152 ...") while keeping the delimiter.
                articles = re.split(r'(Art\.\s+\d+[\.\s])', text)
                for i in range(1, len(articles), 2):
                    article_number = articles[i].strip()
                    article_content = articles[i + 1].strip() if i + 1 < len(articles) else ""
                    source = f"{doc_type}, {article_number}"
                    source_mapper.add_source(source)
                    data.append({
                        "text": f"{article_number} {article_content}",
                        "source_idx": source_mapper.get_idx(source)
                    })
            else:
                chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
                for chunk in chunks:
                    data.append({
                        "text": chunk,
                        "source_idx": -1  # no source
                    })
    return data
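# Expected file_catalog.json format (an assumption inferred from how
# identify_legal_document uses it): a JSON object mapping file names found
# under the scanned directory to legal-document titles; anything missing
# falls back to "Opracowanie własne". A minimal catalog might look like:
#
#   {
#       "kodeks_pracy.txt": "Kodeks pracy",
#       "kodeks_cywilny.pdf": "Kodeks cywilny"
#   }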
def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # For causal LM training the labels are a copy of the inputs.
    tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized


def custom_collate_fn(batch):
    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
    labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
    # Default source_idx to -1 ("no source") when the field is missing.
    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
    return {"input_ids": input_ids, "attention_mask": attention_mask,
            "labels": labels, "source_idx": source_idx}


class CustomModel(GPTNeoForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # One learned bias vector per source, added to the logits. The bias
        # must live in vocab space (not hidden space) to be addable to the
        # logits. Capacity of 1000 sources is a hard cap; SourceMapper must
        # not exceed it.
        self.source_embedding = nn.Embedding(
            num_embeddings=1000,
            embedding_dim=config.vocab_size,
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                source_idx=None, **kwargs):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        if source_idx is not None:
            # Indices of -1 mean "no source": clamp them for the lookup and
            # zero out their contribution (negative indices would crash
            # nn.Embedding).
            mask = (source_idx >= 0).unsqueeze(-1).to(outputs.logits.dtype)
            source_bias = self.source_embedding(source_idx.clamp(min=0)) * mask
            outputs.logits = outputs.logits + source_bias.unsqueeze(1)
            if labels is not None:
                # Recompute the LM loss on the biased logits so the source
                # embedding actually receives gradients.
                shift_logits = outputs.logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss_fct = nn.CrossEntropyLoss()
                outputs.loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )
        return outputs


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        source_idx = inputs.pop("source_idx")
        outputs = model(**inputs, labels=labels, source_idx=source_idx)
        return (outputs.loss, outputs) if return_outputs else outputs.loss


# Initialize components
source_mapper = SourceMapper()
model_name = "EleutherAI/gpt-neo-2.7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Prepare the data
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path, source_mapper)
dataset = Dataset.from_list(data)
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)

# Initialize the model (load the config directly instead of instantiating a
# throwaway 2.7B model just to read its config)
config = AutoConfig.from_pretrained(model_name)
model = CustomModel.from_pretrained(model_name, config=config)
model.resize_token_embeddings(len(tokenizer))
model.gradient_checkpointing_enable()

# Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    fp16=True,
    logging_steps=100,
    save_strategy="steps",
    save_steps=1000,
    report_to="none",
    gradient_checkpointing=True,
    per_device_train_batch_size=4,   # training batch size
    per_device_eval_batch_size=4,    # evaluation batch size
    logging_dir='./logs'             # log directory
)

# Training
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=custom_collate_fn  # use the custom collate_fn
)
trainer.train()


# Answer-generation function
def generate_answer(question, model, tokenizer, source_mapper, max_new_tokens=200):
    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,  # bound the generated length, not the total length
        num_return_sequences=1,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Placeholder heuristic: derive a source index from the last token id.
    # This is not a reliable attribution mechanism and should be replaced.
    last_token_id = outputs.sequences[0][-1].item()
    source_idx = last_token_id % 1000
    source = source_mapper.get_source(source_idx)
    return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"


# Example usage ("How many days of leave is an employee entitled to?")
question = "Ile dni urlopu przysługuje pracownikowi?"
answer = generate_answer(question, model, tokenizer, source_mapper)
print(answer)
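# Optional follow-up (a sketch, not part of the original flow): persist the
# learned source mapping so a separate inference process can resolve source
# indices back to citations. The file name "source_map.json" is illustrative.
#
#   with open("source_map.json", "w", encoding="utf-8") as f:
#       json.dump(source_mapper.idx_to_source, f, ensure_ascii=False, indent=2)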