import os
import torch
import random
import re
import json
import PyPDF2
import docx2txt
import pytesseract
import numpy as np
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from nlpaug.augmenter.word import SynonymAug
from huggingface_hub import login

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="hf_xxx")  # Replace with your own Hugging Face token; never commit a real one


class SourceMapper:
    """Maps source labels (document types) to integer indices and back."""

    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")


class LegalProcessor:
    """Extracts text from PDF/DOCX/image/plain-text files and splits it into chunks."""

    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        self.augmenter = SynonymAug(aug_src='wordnet', aug_max=3)

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return defaultdict(str)

    def process_file(self, file_path):
        text = self.extract_text(file_path)
        if not text:
            return []
        doc_type = self.identify_doc_type(file_path)
        return self.split_content(text, doc_type)

    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self.extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self.extract_image(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""

    def extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                # extract_text() can return None for scanned or empty pages
                text += (page.extract_text() or "") + "\n"
        return re.sub(r'\s+', ' ', text)

    def extract_image(self, path):
        return pytesseract.image_to_string(
            Image.open(path),
            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
        )

    def identify_doc_type(self, file_path):
        base = os.path.splitext(os.path.basename(file_path))[0].lower()
        return self.catalog.get(base, "Custom")

    def split_content(self, text, doc_type):
        if doc_type == "Custom":
            return self.split_custom(text)
        return self.split_legal(text, doc_type)

    def split_legal(self, text, doc_type):
        # Split on article / paragraph / chapter markers typical for Polish legal texts
        pattern = r'(?i)(Art[\.\s]*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)'
        parts = re.split(pattern, text)
        results = []
        current_header = ""
        for part in parts:
            if not part:
                continue
            if re.match(pattern, part):
                if current_header:
                    results.append(current_header)
                current_header = f"[{doc_type}] {part.strip()}"
            else:
                if current_header:
                    results.append(f"{current_header}: {part.strip()}")
                    current_header = ""
                else:
                    results.append(part.strip())
        return [chunk for chunk in results if len(chunk) > 50]

    def split_custom(self, text):
        # Sliding-window chunking with overlap for documents without legal structure
        clean_text = re.sub(r'\s+', ' ', text).strip()
        chunk_size = 384
        overlap = 64
        chunks = []
        start = 0
        while start < len(clean_text):
            end = start + chunk_size
            chunks.append(clean_text[start:end])
            start = end - overlap
        return [f"[Custom] {chunk}" for chunk in chunks if chunk.strip()]


def main():
    # Initialize components
    source_mapper = SourceMapper()
    processor = LegalProcessor("file_catalog.json")
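    # LegalProcessor looks up lowercase file basenames in file_catalog.json to
    # decide the document type used in the "[...]" source tag; anything missing
    # from the catalog falls back to "Custom". A hypothetical example of the
    # expected catalog shape (keys and values are assumptions for illustration,
    # not taken from the original project):
    #
    #     {
    #         "kodeks_cywilny": "Kodeks cywilny",
    #         "kodeks_karny": "Kodeks karny"
    #     }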
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data processing
    data = []

    def process_and_augment(file_path):
        try:
            items = processor.process_file(file_path)
            for text in items:
                source = text.split("]")[0][1:]
                source_mapper.add_source(source)
                # Original text
                data.append({
                    "text": text,
                    "source_idx": source_mapper.get_idx(source)
                })
                # Augmentation (recent nlpaug versions return a list of strings)
                augmented = processor.augmenter.augment(text)
                if isinstance(augmented, list):
                    augmented = augmented[0] if augmented else text
                if augmented != text:
                    data.append({
                        "text": augmented,
                        "source_idx": source_mapper.get_idx(source)
                    })
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")

    # Multi-threaded processing
    # NOTE: appends to `data` rely on CPython's GIL being sufficient here;
    # add a lock if the worker is ever moved to heavier parallelism.
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):  # Changed to "files"
            for file in files:
                file_path = os.path.join(root, file)
                futures.append(executor.submit(process_and_augment, file_path))
        for future in futures:
            future.result()

    # The rest of the code remains unchanged...


if __name__ == "__main__":
    main()
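
# --- Hedged sketch (not part of the original script) --------------------------
# The shown portion ends with "The rest of the code remains unchanged...".
# Based solely on the imports above (Dataset, Trainer, TrainingArguments,
# DataCollatorForLanguageModeling), one plausible shape of that elided training
# step is sketched below. The function name, hyperparameters and output
# directory are assumptions for illustration, not the author's code; nothing
# calls this function.
def build_and_train_sketch(data, tokenizer):
    """Turn the collected records into a Dataset and run causal-LM fine-tuning."""
    dataset = Dataset.from_list(data)

    def tokenize(batch):
        return tokenizer(batch["text"], truncation=True, max_length=512)

    # Drop raw text/source columns; the collator builds labels from input_ids
    tokenized = dataset.map(tokenize, batched=True,
                            remove_columns=["text", "source_idx"])

    model = AutoModelForCausalLM.from_pretrained("crumb/nano-mistral")
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    args = TrainingArguments(
        output_dir="legal-model",        # assumed output path
        per_device_train_batch_size=4,   # assumed hyperparameters
        num_train_epochs=3,
        logging_steps=50,
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        data_collator=collator,
    )
    trainer.train()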