"""Fine-tune a small causal LM on legal documents with a learned per-source embedding.

Pipeline (driven by :func:`main`):

1. Walk the ``files`` directory and extract text from txt/md/pdf/doc(x)/image files.
2. Split documents listed in ``file_catalog.json`` into articles ("Art. N") and tag
   each article with a source id; other documents are cut into fixed-size character
   chunks with no source (-1).
3. Train :class:`CustomModel` (base causal LM whose token embeddings are shifted by a
   per-source vector) with the Hugging Face ``Trainer``.
4. Generate answers to a few sample questions and list the catalogued article
   sources referenced in each answer.
"""
import json
import os
import re
from collections import defaultdict

import docx2txt
import pytesseract
import PyPDF2
import torch
import torch.nn as nn
from datasets import Dataset
from huggingface_hub import login
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Configuration
# NOTE(review): TORCH_USE_CUDA_DSA appears to be a PyTorch build-time option; setting it
# after ``torch`` has been imported most likely has no effect. Kept for compatibility.
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Label for files that are not in the catalog ("own material", no legal source).
OWN_MATERIAL = "Opracowanie własne"

# Splits legal text on article headings such as "Art. 154." or "Art. 154 ".
# The capturing group keeps each heading in the ``re.split`` output.
_ARTICLE_SPLIT_RE = re.compile(r'(Art\.\s+\d+[\.\s])')

# Finds article references ("Art. 154", "art. 154") in generated answers.
_ARTICLE_REF_RE = re.compile(r'\bArt\.\s*(\d+)', re.IGNORECASE)

# Extracts the article number from a stored source label "<doc_type>, Art. N[.]".
_SOURCE_ARTICLE_RE = re.compile(r'Art\.\s+(\d+)\.?$')


def _hf_login():
    """Log in to the Hugging Face Hub when ``HF_TOKEN`` is set in the environment.

    Access tokens must not be hard-coded in source; a token that was ever committed
    should be revoked on huggingface.co.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)


class SourceMapper:
    """Bidirectional mapping between source labels and consecutive integer ids.

    Ids are assigned in registration order starting at 0. ``-1`` means "no source".
    """

    def __init__(self):
        self.source_to_idx = {}
        self.idx_to_source = {}

    def __len__(self):
        """Return the number of registered sources."""
        return len(self.source_to_idx)

    def add_source(self, source):
        """Register *source*. Empty or already-known sources are ignored."""
        if source and source not in self.source_to_idx:
            idx = len(self.source_to_idx)
            self.source_to_idx[source] = idx
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        """Return the id of *source*, or ``-1`` for an empty source.

        An unknown source is registered through :meth:`add_source`, so both
        dictionaries always stay consistent.
        """
        if not source:
            return -1
        self.add_source(source)
        return self.source_to_idx[source]

    def get_source(self, idx):
        """Return the label for *idx*, or ``"Unknown"`` if it is not registered."""
        return self.idx_to_source.get(idx, "Unknown")


def load_file_catalog(catalog_path):
    """Load the JSON catalog that maps file names to legal-document labels."""
    with open(catalog_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def identify_legal_document(filename, file_catalog):
    """Return the catalog label for *filename*, or ``OWN_MATERIAL`` if it is not listed."""
    return file_catalog.get(filename, OWN_MATERIAL)


def extract_text_from_file(file_path):
    """Return the plain text of *file_path*, or ``""`` for unsupported extensions.

    Supported formats:
        - .txt/.md, read as UTF-8.
        - .pdf, via PyPDF2.
        - .doc/.docx, via docx2txt.
        - Common image formats, via Tesseract OCR.

    Extraction errors propagate to the caller.
    """
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext in ('.txt', '.md'):
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    if ext == '.pdf':
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            # Page text may be empty (or None in some PyPDF2 versions) for image-only
            # pages. ``join`` also avoids quadratic string concatenation.
            return "".join(page.extract_text() or "" for page in reader.pages)
    if ext in ('.doc', '.docx'):
        # NOTE(review): docx2txt targets the zip-based .docx format. Legacy binary .doc
        # files will likely raise here; prepare_dataset skips such files.
        return docx2txt.process(file_path)
    if ext in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff'):
        # The context manager closes the image file handle after OCR.
        # NOTE(review): Tesseract uses its default language here; consider lang="pol"
        # for Polish scans if that traineddata is installed.
        with Image.open(file_path) as image:
            return pytesseract.image_to_string(image)
    return ""


def _chunk_text(text, chunk_size):
    """Split *text* into consecutive pieces of at most *chunk_size* characters."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


def prepare_dataset(directory, catalog_path, source_mapper, chunk_size=512):
    """Build training records ``{"text": str, "source_idx": int}`` from files under *directory*.

    Catalogued documents are split on "Art. N" headings, and each article gets its
    own source "<doc_type>, Art. N". Text before the first heading, or the whole
    document if it has no headings, is chunked and attributed to ``doc_type`` rather
    than dropped. Uncatalogued files are chunked with ``source_idx == -1``.

    Args:
        directory: root directory to walk recursively.
        catalog_path: path to the JSON file-name -> document-label catalog.
        source_mapper: SourceMapper that receives every source label encountered.
        chunk_size: character length of chunks for non-article text.

    Returns:
        List of record dicts, in deterministic (sorted) file order.
    """
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, dirs, files in os.walk(directory):
        # Sort so that source ids, and therefore the learned source embeddings, are
        # assigned identically on every run.
        dirs.sort()
        for file in sorted(files):
            file_path = os.path.join(root, file)
            try:
                text = extract_text_from_file(file_path)
            except Exception as err:  # one unreadable file must not abort the whole run
                print(f"Pominięto plik {file_path}: {err}")
                continue
            if not text:
                continue

            doc_type = identify_legal_document(file, file_catalog)
            if doc_type == OWN_MATERIAL:
                data.extend(
                    {"text": chunk, "source_idx": -1}
                    for chunk in _chunk_text(text, chunk_size)
                )
                continue

            # parts == [preamble, heading_1, body_1, heading_2, body_2, ...]
            parts = _ARTICLE_SPLIT_RE.split(text)
            preamble = parts[0].strip()
            if preamble:
                doc_idx = source_mapper.get_idx(doc_type)
                data.extend(
                    {"text": chunk, "source_idx": doc_idx}
                    for chunk in _chunk_text(preamble, chunk_size)
                )
            for i in range(1, len(parts), 2):
                article_number = parts[i].strip()
                article_content = parts[i + 1].strip() if i + 1 < len(parts) else ""
                source = f"{doc_type}, {article_number}"
                source_mapper.add_source(source)
                data.append({
                    "text": f"{article_number} {article_content}",
                    "source_idx": source_mapper.get_idx(source),
                })
    return data


def tokenize_function(examples, tokenizer=None, max_length=512):
    """Tokenize a batch of records for causal-LM training.

    Intended for ``Dataset.map(batched=True)``.

    Args:
        examples: batch dict with ``"text"`` (list of str) and ``"source_idx"``
            (list of int).
        tokenizer: Hugging Face tokenizer. It must be supplied explicitly, e.g. via
            ``Dataset.map(..., fn_kwargs={"tokenizer": tok})``.
        max_length: every sequence is truncated or padded to this many tokens.
            Longer texts lose their tail.

    Returns:
        The tokenizer output, plus two extra keys:
            - ``labels``: the input ids with padded positions set to -100, so the
              loss ignores them.
            - ``source_idx``: passed through unchanged.

    Raises:
        ValueError: if no tokenizer is given.
    """
    if tokenizer is None:
        raise ValueError("tokenize_function requires a tokenizer (pass it via fn_kwargs)")
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    # -100 is the ignore index of the LM cross-entropy loss. Without this the model is
    # trained to emit pad (== eos) tokens on every padded position.
    tokenized["labels"] = tokenized["input_ids"].masked_fill(
        tokenized["attention_mask"] == 0, -100
    )
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized


def custom_collate_fn(batch):
    """Collate tokenized dataset rows into a batch dict of ``torch.long`` tensors.

    Rows without ``source_idx`` get -1 ("no source").
    """
    def stack(key):
        return torch.tensor([row[key] for row in batch], dtype=torch.long)

    return {
        "input_ids": stack("input_ids"),
        "attention_mask": stack("attention_mask"),
        "labels": stack("labels"),
        "source_idx": torch.tensor(
            [row.get("source_idx", -1) for row in batch], dtype=torch.long
        ),
    }


class CustomModel(nn.Module):
    """Causal LM whose input token embeddings are shifted by a learned per-source vector.

    Embedding layout:
        - Source id ``k`` uses row ``k + 1`` of ``source_embedding``.
        - Row 0 is a fixed zero padding row. It is used for "no source"
          (``source_idx < 0``) and for ids ``>= num_sources``, so those never
          collide with a real source.
    """

    def __init__(self, model_name, config, num_sources=1000):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        self.num_sources = num_sources
        self.source_embedding = nn.Embedding(
            num_sources + 1, config.hidden_size, padding_idx=0
        )
        # nn.Embedding initialises weights from N(0, 1), typically far larger than
        # pretrained token embeddings. Starting from zero keeps the base model
        # unperturbed at the beginning of training.
        nn.init.zeros_(self.source_embedding.weight)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        """Run the base model, adding source embeddings when ``source_idx`` is given."""
        if source_idx is None or input_ids is None:
            return self.base_model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
            )
        valid = (source_idx >= 0) & (source_idx < self.num_sources)
        emb_idx = torch.where(valid, source_idx + 1, torch.zeros_like(source_idx))
        # (batch,) -> (batch, 1, hidden), broadcast over every token position.
        source_embeds = (
            self.source_embedding(emb_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
        )
        inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
        return self.base_model(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs
        )

    def generate(self, *args, **kwargs):
        """Delegate to the base model's ``generate``. Source embeddings are not applied."""
        return self.base_model.generate(*args, **kwargs)


def main():
    """Prepare the dataset, fine-tune the model and print answers to sample questions."""
    # Initialization
    _hf_login()
    source_mapper = SourceMapper()
    model_name = "crumb/nano-mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the pad token (presumably the tokenizer defines none). Padded
    # positions are excluded from the loss in tokenize_function.
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    catalog_path = "file_catalog.json"
    data = prepare_dataset("files", catalog_path, source_mapper)
    if not data:
        print("Brak danych treningowych w katalogu 'files'.")
        return
    dataset = Dataset.from_list(data)
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        batch_size=8,
        fn_kwargs={"tokenizer": tokenizer},
    )

    # Model: AutoConfig reads only the config instead of loading the full weights twice.
    config = AutoConfig.from_pretrained(model_name)
    model = CustomModel(model_name, config, num_sources=max(len(source_mapper), 1))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        fp16=torch.cuda.is_available(),
        logging_steps=1,
        save_strategy="steps",
        save_steps=1000,
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=custom_collate_fn,
    )
    print("Rozpoczęcie treningu...")
    trainer.train()

    # Testing
    model.eval()
    test_device = next(model.parameters()).device

    # Index catalogued article sources by article number once, for exact matching.
    sources_by_article = defaultdict(set)
    for source in source_mapper.idx_to_source.values():
        match = _SOURCE_ARTICLE_RE.search(source)
        if match:
            sources_by_article[match.group(1)].add(source)

    def generate_answer(question):
        """Generate an answer and list the catalogued sources of the articles it cites."""
        inputs = tokenizer(question, return_tensors="pt").to(test_device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens. String-replacing the question fails
        # whenever decoding does not reproduce the prompt verbatim.
        prompt_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        sources = set()
        for ref in _ARTICLE_REF_RE.finditer(answer):
            # Exact number match: "Art. 15" must not pick up the sources of Art. 154.
            sources.update(sources_by_article.get(ref.group(1), ()))
        return {
            "question": question,
            "answer": answer,
            "sources": sorted(sources) if sources else [OWN_MATERIAL],
        }

    # Sample tests
    test_questions = [
        "Jakie są zasady udzielania urlopu wypoczynkowego?",
        "Co mówi art. 154 kodeksu pracy?",
        "Jakie są obowiązki pracodawcy w zakresie BHP?",
    ]
    print("\n" + "="*50 + "\nWYNIKI TESTOW\n" + "="*50)
    for question in test_questions:
        result = generate_answer(question)
        print(f"\nPYTANIE: {result['question']}")
        print(f"ODPOWIEDŹ: {result['answer'][:500]}")
        print(f"ŹRÓDŁA: {', '.join(result['sources'])}")
        print("-"*80)


if __name__ == "__main__":
    main()