import os
import torch
import random
import re
import json
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
import nlpaug.augmenter.word as naw
from huggingface_hub import login

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Read the token from the environment (set HF_TOKEN before running);
# never hard-code credentials in the source.
login(token=os.environ["HF_TOKEN"])


class SourceMapper:
    """Bidirectional mapping between source names and integer indices."""

    def __init__(self):
        # The defaultdict mints a fresh index (current size) for each new source.
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        if not source:
            return -1
        # Route through add_source so the reverse map stays in sync: the bare
        # defaultdict lookup would otherwise mint an index for an unseen source
        # without ever recording it in idx_to_source.
        self.add_source(source)
        return self.source_to_idx[source]

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
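# Usage sketch for SourceMapper ("Kodeks cywilny" is an illustrative source
# name, not necessarily a catalog entry):
#
#   mapper = SourceMapper()
#   mapper.add_source("Kodeks cywilny")
#   mapper.get_idx("Kodeks cywilny")   # -> 0 (first source registered)
#   mapper.get_source(0)               # -> "Kodeks cywilny"
#   mapper.get_idx(None)               # -> -1 (sentinel for a missing source)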
class LegalDataProcessor:
    """Extracts text from mixed file formats and splits it into training chunks."""

    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        # Antonym augmenter from nlpaug; initialized here but unused below —
        # main() relies on sentence shuffling for augmentation instead.
        self.augmenter = naw.AntonymAug()

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            # No usable catalog: every file falls back to the default doc type.
            return {}

    def identify_document(self, filename):
        base = os.path.splitext(filename)[0].lower()
        # "Opracowanie własne" (own material) marks files absent from the catalog.
        return self.catalog.get(base, "Opracowanie własne")

    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self._extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self._extract_ocr(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return ""

    def _extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                # extract_text() can come back empty/None on image-only pages.
                text += (page.extract_text() or "") + "\n"
        return re.sub(r'\s+', ' ', text)

    def _extract_ocr(self, path):
        return pytesseract.image_to_string(
            Image.open(path),
            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
        )

    def process_legal(self, text, doc_type):
        # Split on Polish statute headings: "Art. 5", "§ 2", "Rozdział IV".
        articles = re.split(
            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
            text
        )
        processed = []
        current_header = ""
        for item in articles:
            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
                # Two headers in a row: flush the dangling one on its own.
                if current_header:
                    processed.append(current_header)
                current_header = item.strip()
            elif current_header:
                # Body text following a header: emit them as one chunk.
                processed.append(current_header + " " + item.strip())
                current_header = ""
            else:
                processed.append(item.strip())
        return [
            (f"[{doc_type}] {p}", doc_type)
            for p in processed if len(p) > 30
        ]

    def process_custom(self, text):
        # Fixed-size character windows with overlap, for uncatalogued material.
        clean_text = re.sub(r'\s+', ' ', text).strip()
        chunk_size = 384
        overlap = 128
        chunks = [
            clean_text[i:i + chunk_size]
            for i in range(0, len(clean_text), chunk_size - overlap)
        ]
        return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]


class EnhancedDataCollator(DataCollatorForLanguageModeling):
    """Causal-LM collator that carries a per-example source index through the batch."""

    def torch_call(self, examples):
        # Pull source_idx out before the base collator pads the batch;
        # tokenizer.pad() has no padding rule for the extra key.
        source_idx = None
        if "source_idx" in examples[0]:
            source_idx = [ex.pop("source_idx") for ex in examples]
        batch = super().torch_call(examples)
        if source_idx is not None:
            batch["source_idx"] = torch.tensor(source_idx, dtype=torch.long)
        return batch


def main():
    # Setup
    source_mapper = SourceMapper()
    processor = LegalDataProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    data = []

    def process_file(file_path):
        text = processor.extract_text(file_path)
        if not text:
            return
        doc_type = processor.identify_document(os.path.basename(file_path))
        if doc_type != "Opracowanie własne":
            processed = processor.process_legal(text, doc_type)
        else:
            processed = processor.process_custom(text)
        for chunk, source in processed:
            source_mapper.add_source(source)
            # list.append is atomic under the GIL, so appending from worker
            # threads is safe here.
            data.append({
                "text": chunk,
                "source_idx": source_mapper.get_idx(source)
            })

    # Multithreaded file processing
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):
            for file in files:
                futures.append(executor.submit(
                    process_file,
                    os.path.join(root, file)
                ))
        for future in futures:
            try:
                future.result()
            except Exception as e:
                print(f"Error: {e}")

    # Augmentation: two extra sentence-shuffled variants per example.
    print(f"Before augmentation: {len(data)} examples")
    augmented = []
    for item in data:
        for _ in range(2):
            sentences = item['text'].split('. ')
            random.shuffle(sentences)
            augmented.append({
                "text": '. '.join(sentences),
                "source_idx": item["source_idx"]
            })
    data += augmented
    print(f"After augmentation: {len(data)} examples")

    # Dataset preparation
    dataset = Dataset.from_list(data)

    def tokenize_fn(examples):
        # Return plain lists: datasets.map handles them natively, and the old
        # return_tensors="pt" + .squeeze() combination broke for batches of one.
        # Labels are left to the collator, which copies input_ids (mlm=False).
        tokenized = tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True
        )
        tokenized["source_idx"] = examples["source_idx"]
        return tokenized

    tokenized_ds = dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=32,
        num_proc=4
    )

    # Model
    model = AutoModelForCausalLM.from_pretrained(
        "crumb/nano-mistral",
        trust_remote_code=True
    )
    # No new tokens were added (pad reuses eos), so this resize is a no-op,
    # kept for safety in case special tokens are added later.
    model.resize_token_embeddings(len(tokenizer))

    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )

    # Note: with the default remove_unused_columns=True, Trainer strips dataset
    # columns that model.forward() does not accept ("text", "source_idx") before
    # they reach the collator; source_idx is therefore bookkeeping metadata via
    # source_mapper, not a model input.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
    )

    print("Starting training...")
    trainer.train()
    print("Training finished!")

    # Save the model
    model.save_pretrained("./trained_model")
    tokenizer.save_pretrained("./trained_model")


if __name__ == "__main__":
    main()
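# --- Usage sketch (kept commented out so running the script is unchanged):
# --- loading the fine-tuned model for inference. Assumes training completed
# --- and ./trained_model exists; the prompt text is purely illustrative.
#
# from transformers import AutoTokenizer, AutoModelForCausalLM
#
# tok = AutoTokenizer.from_pretrained("./trained_model")
# lm = AutoModelForCausalLM.from_pretrained("./trained_model", trust_remote_code=True)
# inputs = tok("[Custom] Umowa najmu lokalu", return_tensors="pt")
# out = lm.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.9)
# print(tok.decode(out[0], skip_special_tokens=True))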