import os
import re
import json
from collections import defaultdict

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from huggingface_hub import login

# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Read the Hugging Face token from the environment instead of hardcoding it
# in the source (the original embedded a literal token, which leaks the secret).
login(token=os.environ.get("HF_TOKEN"))


class SourceMapper:
    """Maps source labels (e.g. "Kodeks pracy, Art. 1") to consecutive integer indices."""

    def __init__(self):
        # The defaultdict hands out the next free index on first lookup of a new source
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        # -1 is the sentinel for "no source"
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")


def load_file_catalog(catalog_path):
    try:
        with open(catalog_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except Exception as e:
        print(f"Error loading file catalog: {str(e)}")
        return {}


def identify_legal_document(filename, file_catalog):
    base_name = os.path.splitext(filename)[0].lower()
    # "Opracowanie własne" ("own material") marks documents without a statutory source
    return file_catalog.get(base_name, "Opracowanie własne")


def extract_text_from_file(file_path):
    try:
        _, ext = os.path.splitext(file_path)
        ext = ext.lower()
        if ext in ['.txt', '.md']:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        elif ext == '.pdf':
            text = ""
            try:
                with open(file_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    for page in reader.pages:
                        text += page.extract_text() or ""
            except Exception as e:
                print(f"PDF error: {str(e)}")
            return text
        elif ext in ['.doc', '.docx']:
            return docx2txt.process(file_path)
        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
            return pytesseract.image_to_string(Image.open(file_path))
        else:
            print(f"Unsupported file format: {ext}")
            return ""
    except Exception as e:
        print(f"Text extraction error: {str(e)}")
        return ""
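# Hedged sketch of the article splitting used in prepare_dataset() below:
# because the pattern contains a capturing group, re.split keeps the matched
# delimiters in the result, so articles[1::2] are the headers ("Art. 1.") and
# articles[2::2] are their bodies. The sample string is illustrative only:
#
#   re.split(r'(Art\. \d+\.?)', "Preambula Art. 1. tresc Art. 2. tresc2")
#   -> ['Preambula ', 'Art. 1.', ' tresc ', 'Art. 2.', ' tresc2']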
def prepare_dataset(directory, catalog_path, source_mapper):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    print(f"\n{'='*50}\nDATA DIAGNOSTICS\n{'='*50}")
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            print(f"\nProcessing file: {file_path}")
            try:
                text = extract_text_from_file(file_path)
                if not text.strip():
                    print("Skipped - no text")
                    continue
                print(f"Text length: {len(text)} characters")
                doc_type = identify_legal_document(file, file_catalog)
                print(f"Detected document type: {doc_type}")
                if doc_type != "Opracowanie własne":
                    # Regular expression for the "Art. XX." header format
                    articles = re.split(r'(Art\. \d+\.?)', text)
                    print(f"Found {len(articles)} fragments")
                    for i in range(1, len(articles), 2):
                        article_number = articles[i].strip()
                        article_content = articles[i + 1].strip() if i + 1 < len(articles) else ""
                        if not article_content:
                            continue
                        source = f"{doc_type}, {article_number}"
                        print(f"Added article: {source}")
                        source_mapper.add_source(source)
                        data.append({
                            "text": f"{article_number} {article_content}",
                            "source_idx": source_mapper.get_idx(source)
                        })
                else:
                    # No statutory structure: normalize whitespace and cut into 512-character chunks
                    clean_text = re.sub(r'\s+', ' ', text).strip()
                    chunks = [clean_text[i:i + 512] for i in range(0, len(clean_text), 512)]
                    chunks = [c for c in chunks if c.strip()]
                    for chunk in chunks:
                        data.append({"text": chunk, "source_idx": -1})
                    print(f"Added {len(chunks)} chunks")
            except Exception as e:
                print(f"Error while processing file: {str(e)}")
                continue

    print("\nData preparation summary:")
    print(f"Total number of examples: {len(data)}")
    if data:
        print("Sample entry:")
        print(json.dumps(data[0], indent=2, ensure_ascii=False))
    else:
        print("NO DATA - check the diagnostics above")
    return data


class CustomModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        # padding_idx=-1 resolves to the last row (index 999): a zero vector
        # that receives no gradient, reserved for "no source" examples
        self.source_embedding = nn.Embedding(1000, config.hidden_size, padding_idx=-1)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        if source_idx is not None:
            # Map the -1 sentinel to the padding row instead of clamping it to 0,
            # which would collide with the embedding of the first real source
            pad_row = self.source_embedding.num_embeddings - 1
            valid_indices = torch.where(
                source_idx < 0,
                torch.full_like(source_idx, pad_row),
                torch.clamp(source_idx, 0, pad_row),
            )
            # Shift every token embedding in the sequence by the per-source vector
            source_embeds = self.source_embedding(valid_indices).unsqueeze(1).expand(-1, input_ids.size(1), -1)
            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                   labels=labels, **kwargs)
        return self.base_model(input_ids=input_ids, attention_mask=attention_mask,
                               labels=labels, **kwargs)

    def generate(self, *args, **kwargs):
        return self.base_model.generate(*args, **kwargs)


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        source_idx = inputs.pop("source_idx", None)
        outputs = model(**inputs, labels=labels, source_idx=source_idx)
        return (outputs.loss, outputs) if return_outputs else outputs.loss
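# Hedged shape sketch (not called anywhere in the pipeline): mirrors the
# broadcast in CustomModel.forward with small illustrative sizes. One embedding
# row per example, shape (batch, hidden), is expanded across the sequence
# dimension and added to the token embeddings. All sizes are assumptions.
def _demo_source_shift():
    emb = nn.Embedding(10, 8, padding_idx=9)     # row 9 = zero "no source" vector
    source_idx = torch.tensor([3, 9])            # second example has no source
    token_embeds = torch.randn(2, 5, 8)          # (batch, seq, hidden)
    shift = emb(source_idx).unsqueeze(1).expand(-1, 5, -1)  # (2, 5, 8)
    return token_embeds + shift                  # second example is left unchanged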
def main():
    # Initialization
    source_mapper = SourceMapper()
    model_name = "crumb/nano-mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    catalog_path = "file_catalog.json"
    data = prepare_dataset("files", catalog_path, source_mapper)
    if not data:
        print("\nNo training data! Check the files in the 'files' directory and the diagnostics above.")
        return

    dataset = Dataset.from_list(data)

    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )
        return {
            "input_ids": tokenized["input_ids"][0],
            "attention_mask": tokenized["attention_mask"][0],
            # Causal LM training: labels are a copy of the input ids
            "labels": tokenized["input_ids"][0].clone(),
            "source_idx": examples["source_idx"]
        }

    tokenized_dataset = dataset.map(tokenize_function, batched=False)

    def custom_collate_fn(features):
        return {
            "input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in features]),
            "attention_mask": torch.stack([torch.tensor(f["attention_mask"]) for f in features]),
            "labels": torch.stack([torch.tensor(f["labels"]) for f in features]),
            "source_idx": torch.tensor([f["source_idx"] for f in features], dtype=torch.long)
        }

    # Model: load only the config here; CustomModel loads the weights itself
    config = AutoConfig.from_pretrained(model_name)
    model = CustomModel(model_name, config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        fp16=torch.cuda.is_available(),
        logging_steps=10,
        save_strategy="steps",
        save_steps=500,
        report_to="none",
        remove_unused_columns=False
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=custom_collate_fn
    )

    print("\nStarting training...")
    trainer.train()

    # Testing
    def generate_answer(question):
        model.eval()
        # "[PYTANIE PRAWNE]" = "[LEGAL QUESTION]"; kept as-is, it is the literal prompt prefix
        prompt = f"[PYTANIE PRAWNE] {question}"
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.5,
                no_repeat_ngram_size=3,
                pad_token_id=tokenizer.eos_token_id
            )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = answer.replace(prompt, "").strip()

        # Note: simple substring matching, so "art. 1" also matches "Art. 10" sources
        sources = set()
        for match in re.finditer(r'(?i)art\.?\s*\d+\.?', answer):
            article_ref = match.group(0).strip().rstrip('.')
            for source in source_mapper.idx_to_source.values():
                if article_ref.lower() in source.lower():
                    sources.add(source)
        return {
            "question": question,
            "answer": answer,
            "sources": list(sources) if sources else ["Opracowanie własne"]
        }

    # Test queries (kept in Polish: the model is trained on Polish legal text)
    test_questions = [
        "Jakie są prawa pracownika według art. 1?",
        "Kto jest pracownikiem według art. 2?",
        "Jakie są obowiązki pracodawcy według art. 3?"
    ]

    print("\n" + "="*50 + "\nTEST RESULTS\n" + "="*50)
    for question in test_questions:
        result = generate_answer(question)
        print(f"\nQUESTION: {result['question']}")
        print(f"ANSWER: {result['answer'][:500]}")
        print(f"SOURCES: {', '.join(result['sources'])}")
        print("-"*80)


if __name__ == "__main__":
    main()
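
# Hedged helper (illustrative, never invoked by the pipeline): writes a starter
# file_catalog.json in the layout identify_legal_document() expects, where keys
# are lowercase base filenames without extension and values are the document
# titles used as source labels. The sample entries are assumptions, not real data.
def write_sample_catalog(path="file_catalog.json"):
    sample = {
        "kodeks_pracy": "Kodeks pracy",               # matches kodeks_pracy.pdf, .docx, ...
        "notatki_szkoleniowe": "Opracowanie własne",  # falls into the chunking branch
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(sample, f, ensure_ascii=False, indent=2)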