import os
import torch
import random
import re
import json
import PyPDF2
import docx2txt
import pytesseract
import numpy as np
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
import nlpaug.augmenter.word as naw  # word-level augmenters (e.g. AntonymAug)
from huggingface_hub import login

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token=os.environ.get("HF_TOKEN", ""))  # Replace with a real token or set the HF_TOKEN env var


class SourceMapper:
    """Maps document source names to stable integer indices and back."""

    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
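

# Illustrative sketch (not called anywhere): the intended SourceMapper round-trip,
# using "Kodeks cywilny" as a hypothetical source name.
def _source_mapper_demo():
    mapper = SourceMapper()
    mapper.add_source("Kodeks cywilny")
    idx = mapper.get_idx("Kodeks cywilny")  # first registered source -> index 0
    assert mapper.get_source(idx) == "Kodeks cywilny"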


class LegalDataProcessor:
    """Extracts raw text from mixed document formats and splits it into training passages."""

    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        self.augmenter = naw.AntonymAug()  # word-level antonym augmenter (not used further in this script)

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return defaultdict(str)

    def identify_document(self, filename):
        base = os.path.splitext(filename)[0].lower()
        return self.catalog.get(base, "Opracowanie własne")

    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self._extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                # docx2txt only parses .docx; legacy .doc files will raise and be caught below
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self._extract_ocr(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""

    def _extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                text += page.extract_text() + "\n"
        return re.sub(r'\s+', ' ', text)

    def _extract_ocr(self, path):
        return pytesseract.image_to_string(
            Image.open(path),
            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
        )

    def process_legal(self, text, doc_type):
        # Split on legal structure markers (Art. N, § N, Rozdział <roman numeral>);
        # the capture group keeps each marker so it can be re-attached to its body.
        articles = re.split(
            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
            text
        )
        processed = []
        current_header = ""

        for item in articles:
            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
                if current_header:
                    processed.append(current_header)
                current_header = item.strip()
            elif current_header:
                processed.append(current_header + " " + item.strip())
                current_header = ""
            else:
                processed.append(item.strip())

        return [
            (f"[{doc_type}] {p}", doc_type)
            for p in processed if len(p) > 30
        ]
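
    # Illustrative sketch of the expected output shape, assuming a hypothetical statute
    # snippet and a hypothetical doc_type "Kodeks cywilny":
    #   process_legal("Art. 1 Pierwszy przepis o wystarczająco długiej treści. Art. 2 ...", "Kodeks cywilny")
    #   -> [("[Kodeks cywilny] Art. 1 Pierwszy przepis o wystarczająco długiej treści.", "Kodeks cywilny"), ...]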

    def process_custom(self, text):
        clean_text = re.sub(r'\s+', ' ', text).strip()
        chunk_size = 384
        overlap = 128

        chunks = [
            clean_text[i:i + chunk_size]
            for i in range(0, len(clean_text), chunk_size - overlap)
        ]
        return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]
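
    # Chunking arithmetic: with chunk_size=384 and overlap=128, window starts advance by
    # 384 - 128 = 256 characters (i = 0, 256, 512, ...), so each 384-character chunk shares
    # its last 128 characters with the start of the next one.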


class EnhancedDataCollator(DataCollatorForLanguageModeling):
    """Standard causal-LM collator that also carries the per-example source index into the batch."""

    def torch_call(self, examples):
        batch = super().torch_call(examples)
        if "source_idx" in examples[0]:
            batch["source_idx"] = torch.tensor(
                [ex["source_idx"] for ex in examples],
                dtype=torch.long
            )
        return batch
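
# Note: with the Trainer default remove_unused_columns=True, dataset columns the model's
# forward() does not accept (such as "source_idx") are dropped before batching, so the
# source_idx handling above only takes effect if that setting is disabled and the model
# accepts such an argument.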


def main():
    # Configuration
    source_mapper = SourceMapper()
    processor = LegalDataProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    data = []

    def process_file(file_path):
        nonlocal data
        text = processor.extract_text(file_path)
        if not text:
            return

        doc_type = processor.identify_document(os.path.basename(file_path))
        if doc_type != "Opracowanie własne":
            processed = processor.process_legal(text, doc_type)
        else:
            processed = processor.process_custom(text)

        for passage, source in processed:
            source_mapper.add_source(source)
            data.append({
                "text": passage,
                "source_idx": source_mapper.get_idx(source)
            })

    # Multithreaded processing
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):
            for file in files:
                futures.append(executor.submit(
                    process_file,
                    os.path.join(root, file)
                ))

        for future in futures:
            try:
                future.result()
            except Exception as e:
                print(f"Error: {str(e)}")

    # Augmentation: add shuffled-sentence variants of every example
    print(f"Before augmentation: {len(data)} examples")
    augmented = []
    for item in data:
        for _ in range(2):  # 2 additional variants
            sentences = item['text'].split('. ')
            random.shuffle(sentences)
            augmented.append({
                "text": '. '.join(sentences),
                "source_idx": item["source_idx"]
            })
    data += augmented
    print(f"After augmentation: {len(data)} examples")

    # Dataset preparation
    dataset = Dataset.from_list(data)

    def tokenize_fn(examples):
        # Batched mapping: return plain Python lists; labels mirror input_ids for causal LM.
        tokenized = tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True
        )
        tokenized["labels"] = tokenized["input_ids"].copy()
        tokenized["source_idx"] = examples["source_idx"]
        return tokenized

    tokenized_ds = dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=32,
        num_proc=4
    )

    # Model
    model = AutoModelForCausalLM.from_pretrained(
        "crumb/nano-mistral",
        trust_remote_code=True
    )
    model.resize_token_embeddings(len(tokenizer))

    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )
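    # Effective batch size per device: 2 (per-device batch) x 8 (accumulation steps) = 16 sequences per optimizer step.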

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
    )
print("Rozpoczęcie treningu...")
|
2025-02-25 14:38:44 -05:00
|
|
|
trainer.train()
|
2025-02-25 16:17:13 -05:00
|
|
|
print("Trening zakończony!")
|
|
|
|
|
|
|
|
|
|
# Zapisz model
|
|
|
|
|
model.save_pretrained("./trained_model")
|
|
|
|
|
tokenizer.save_pretrained("./trained_model")
|
2025-02-25 14:38:44 -05:00
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
main()
|