import os
import re
import json
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset, Features, Value
from huggingface_hub import login

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Never hard-code API tokens in source (the original embedded a literal token
# here); read the Hugging Face token from the environment instead.
login(token=os.environ["HF_TOKEN"])


class LegalAITrainer:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class SourceMapper:
        """Maps source labels (e.g. "Kodeks cywilny, Art. 415") to integer ids."""

        def __init__(self):
            self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
            self.idx_to_source = {}

        def add_source(self, source):
            if source and source not in self.source_to_idx:
                idx = self.source_to_idx[source]
                self.idx_to_source[idx] = source

        def get_idx(self, source):
            # Use .get() so looking up an unknown source does not silently
            # mint a new id via the defaultdict factory.
            return self.source_to_idx.get(source, -1) if source else -1

        def get_source(self, idx):
            return self.idx_to_source.get(idx, "Unknown")

    class LegalModel(nn.Module):
        """Frozen causal LM plus a trainable source embedding and confidence head."""

        def __init__(self, model_name, config):
            super().__init__()
            self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
            # padding_idx=-1 resolves to the last row (num_embeddings - 1),
            # which stays zero; it is reserved for text without a legal source.
            self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
            self.confidence_layer = nn.Linear(config.hidden_size, 1)
            # Train only the two new heads; keep the base LM frozen.
            for param in self.base_model.parameters():
                param.requires_grad = False
            for layer in [self.source_embedding, self.confidence_layer]:
                for param in layer.parameters():
                    param.requires_grad = True

        def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None):
            if source_idx is not None:
                # Map "no source" (-1) onto the zeroed padding row instead of
                # clamping it onto a real source id (index 0).
                pad_row = self.source_embedding.num_embeddings - 1
                source_idx = torch.where(
                    source_idx < 0,
                    torch.full_like(source_idx, pad_row),
                    source_idx
                )
                source_idx = torch.clamp(source_idx, 0, pad_row).long()
                source_embeds = self.source_embedding(source_idx).unsqueeze(1)
                inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
                outputs = self.base_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True  # required by the confidence head below
                )
            else:
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True
                )
            confidence = torch.sigmoid(
                self.confidence_layer(outputs.hidden_states[-1].mean(dim=1))
            )
            return {
                "loss": outputs.loss,
                "logits": outputs.logits,
                "confidence": confidence,
                "hidden_states": outputs.hidden_states
            }

    def load_file_catalog(self, catalog_path):
        try:
            with open(catalog_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading catalog: {e}")
            return {}

    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext in ['.txt', '.md']:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
            elif ext == '.pdf':
                text = ""
                with open(file_path, 'rb') as f:
                    reader = PyPDF2.PdfReader(f)
                    for page in reader.pages:
                        text += page.extract_text() or ""
                return text
            elif ext in ['.doc', '.docx']:
                # docx2txt only handles .docx; legacy .doc files will raise
                # and be caught by the handler below.
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return pytesseract.image_to_string(Image.open(file_path))
            else:
                return ""
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return ""
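    # A minimal usage sketch of SourceMapper (the source label below is a
    # hypothetical example, not taken from any real catalog):
    #
    #   mapper = LegalAITrainer.SourceMapper()
    #   mapper.add_source("Kodeks cywilny, Art. 415")
    #   mapper.get_idx("Kodeks cywilny, Art. 415")   # -> 0
    #   mapper.get_source(0)                         # -> "Kodeks cywilny, Art. 415"
    #   mapper.get_idx(None)                         # -> -1 (text without a source)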
    def prepare_data(self, data_dir, catalog_path):
        catalog = self.load_file_catalog(catalog_path)
        data = []
        source_mapper = self.SourceMapper()
        for root, _, files in os.walk(data_dir):
            for file in files:
                file_path = os.path.join(root, file)
                text = self.extract_text(file_path)
                if not text:
                    continue
                # "Opracowanie własne" ("own compilation") marks uncited material.
                doc_type = catalog.get(os.path.splitext(file)[0].lower(), "Opracowanie własne")
                if doc_type != "Opracowanie własne":
                    # Split statute text on article headers such as "Art. 415a".
                    articles = re.split(r'(?i)(Art\.\s*\d+[a-z]*)', text)
                    for i in range(1, len(articles) - 1, 2):
                        art_num = articles[i].strip()
                        content = articles[i + 1].strip()
                        if len(content) < 100:
                            continue
                        source = f"{doc_type}, {art_num}"
                        source_mapper.add_source(source)
                        data.append({
                            "text": f"[LEGAL] {art_num} {content}",
                            "source_idx": source_mapper.get_idx(source),
                            "is_legal": 1
                        })
                else:
                    chunks = [f"[GENERAL] {text[i:i + 512]}" for i in range(0, len(text), 512)]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source_idx": -1,
                            "is_legal": 0
                        })
        features = Features({
            "text": Value("string"),
            "source_idx": Value("int32"),
            "is_legal": Value("int32")
        })
        return Dataset.from_dict({
            "text": [d["text"] for d in data],
            "source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32),
            "is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32)
        }, features=features), source_mapper

    def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
        dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        def tokenize_fn(examples):
            # Return plain lists; the data collator builds tensors and labels.
            tokenized = tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=512
            )
            tokenized["source_idx"] = examples["source_idx"]
            return tokenized

        # Drop the raw columns ("text", "is_legal") so that only model inputs
        # reach the collator; tokenize_fn re-emits "source_idx".
        tokenized_dataset = dataset.map(
            tokenize_fn,
            batched=True,
            batch_size=16,
            remove_columns=dataset.column_names
        )

        class CustomDataCollator(DataCollatorForLanguageModeling):
            def torch_call(self, examples):
                # Pull source_idx out before padding, then re-attach as a tensor.
                source_idx = [ex.pop("source_idx") for ex in examples]
                batch = super().torch_call(examples)
                batch["source_idx"] = torch.tensor(source_idx, dtype=torch.long)
                return batch

        # Load only the config here; LegalModel loads the weights itself.
        config = AutoConfig.from_pretrained(model_name)
        model = self.LegalModel(model_name, config).to(self.device)

        training_args = TrainingArguments(
            output_dir="./legal_ai_model",
            num_train_epochs=3,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-5,
            fp16=torch.cuda.is_available(),
            logging_steps=50,
            save_strategy="steps",
            save_steps=500,
            report_to="none",
            remove_unused_columns=False
        )

        class LegalTrainer(Trainer):
            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                outputs = model(**inputs)
                loss = outputs["loss"]
                # Supervise the confidence head: 1.0 for cited legal text, 0.0 otherwise.
                target_conf = (inputs["source_idx"] != -1).float()
                conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(-1), target_conf)
                total_loss = loss + 0.7 * conf_loss
                return (total_loss, outputs) if return_outputs else total_loss

        trainer = LegalTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
        )

        print("Starting training...")
        trainer.train()
        # LegalModel is a plain nn.Module, so it has no save_pretrained();
        # persist its state dict instead.
        os.makedirs("./trained_legal_ai", exist_ok=True)
        torch.save(model.state_dict(), "./trained_legal_ai/model.pt")
        tokenizer.save_pretrained("./trained_legal_ai")
        with open("./trained_legal_ai/source_mapper.json", "w", encoding="utf-8") as f:
            json.dump(source_mapper.idx_to_source, f, ensure_ascii=False)
        print("Training finished!")
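    # The catalog.json consumed by prepare_data() is assumed to map lowercase
    # filename stems to document types; the entries below are hypothetical:
    #
    #   {
    #       "kodeks_cywilny": "Kodeks cywilny",
    #       "notatki_z_wykladu": "Opracowanie własne"
    #   }
    #
    # Any file whose stem is missing from the catalog falls back to
    # "Opracowanie własne" (own compilation, i.e. uncited general text).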
    def generate_response(self, prompt, confidence_threshold=0.65):
        config = AutoConfig.from_pretrained("crumb/nano-mistral")
        model = self.LegalModel("crumb/nano-mistral", config).to(self.device)
        # Restore the fine-tuned heads saved by train(); LegalModel has no
        # from_pretrained() of its own.
        model.load_state_dict(torch.load("./trained_legal_ai/model.pt", map_location=self.device))
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
        with open("./trained_legal_ai/source_mapper.json", "r", encoding="utf-8") as f:
            source_mapper = json.load(f)

        inputs = tokenizer(
            f"[PROMPT] {prompt} [RESPONSE]",
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            # LegalModel itself has no generate(); delegate to the wrapped LM.
            outputs = model.base_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                pad_token_id=tokenizer.eos_token_id,
                output_scores=True,
                return_dict_in_generate=True
            )

        full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        # Heuristic confidence signal: sigmoid of the final-step EOS logit.
        confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()

        # Verify that every cited article appears in a known source.
        citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
        verified = [c for c in citations if any(c in s for s in source_mapper.values())]

        if confidence < confidence_threshold or not verified:
            return "I cannot give a definitive answer based on the available data."
        return f"{full_text}\n\nVerified sources: {', '.join(verified)}"


if __name__ == "__main__":
    legal_ai = LegalAITrainer()
    legal_ai.train(
        model_name="crumb/nano-mistral",
        data_dir="./legal_docs",
        catalog_path="./catalog.json"
    )
    # Polish test prompt: "What are the penalties for non-compliance with GDPR?"
    test_prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"
    print(legal_ai.generate_response(test_prompt))
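    # Inference-only sketch: once ./trained_legal_ai exists from a previous
    # run, the training call above can be skipped entirely (the threshold
    # value here is an arbitrary example):
    #
    #   legal_ai = LegalAITrainer()
    #   print(legal_ai.generate_response(test_prompt, confidence_threshold=0.5))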