# ably.do/hft.py
import os
import torch
import random
import re
import json
import PyPDF2
import docx2txt
import pytesseract
import numpy as np
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
import nlpaug.augmenter.word as naw  # nlpaug exposes augmenters via submodules, not top-level classes
from huggingface_hub import login

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token=os.getenv("HF_TOKEN"))  # read the token from the HF_TOKEN env var; never hardcode credentials


class SourceMapper:
    """Bidirectional mapping between source names and integer ids."""

    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        # use .get so a lookup never triggers the defaultdict factory as a side effect
        return self.source_to_idx.get(source, -1) if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
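

# Example (assumed usage): after mapper.add_source("Kodeks cywilny"),
# mapper.get_idx("Kodeks cywilny") returns 0 and mapper.get_source(0) returns
# the name again; a falsy source yields -1 and unmapped ids yield "Unknown".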


class LegalDataProcessor:
    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        self.augmenter = naw.AntonymAug()  # currently unused; kept for optional word-level augmentation

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return {}

    def identify_document(self, filename):
        base = os.path.splitext(filename)[0].lower()
        # "Opracowanie własne" ("own material") marks files absent from the catalog
        return self.catalog.get(base, "Opracowanie własne")
    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self._extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                # note: docx2txt only parses .docx; legacy .doc files raise and fall through
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self._extract_ocr(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return ""
    def _extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                text += (page.extract_text() or "") + "\n"  # extract_text() may return None
        return re.sub(r'\s+', ' ', text)

    def _extract_ocr(self, path):
        # --psm 4: single column of variably sized text; --oem 3: default engine
        return pytesseract.image_to_string(
            Image.open(path),
            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
        )
    def process_legal(self, text, doc_type):
        # split on legal headings (Art. 12, § 3, Rozdział IV), keeping the delimiters
        articles = re.split(
            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
            text
        )
        processed = []
        current_header = ""
        for item in articles:
            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
                if current_header:
                    processed.append(current_header)
                current_header = item.strip()
            elif current_header:
                processed.append(current_header + " " + item.strip())
                current_header = ""
            else:
                processed.append(item.strip())
        if current_header:  # flush a trailing header that has no body after it
            processed.append(current_header)
        return [
            (f"[{doc_type}] {p}", doc_type)
            for p in processed if len(p) > 30
        ]
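
    # Example (hypothetical input): for text "Art. 415. Kto z winy swej wyrządził
    # szkodę, obowiązany jest do jej naprawienia." and doc_type "Kodeks cywilny",
    # this yields one chunk tagged "[Kodeks cywilny] Art. 415 ..." as a
    # (text, source) pair, since it exceeds the 30-character minimum.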

    def process_custom(self, text):
        clean_text = re.sub(r'\s+', ' ', text).strip()
        chunk_size = 384
        overlap = 128
        chunks = [
            clean_text[i:i + chunk_size]
            for i in range(0, len(clean_text), chunk_size - overlap)
        ]
        return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]


class EnhancedDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        if "source_idx" in examples[0]:
            batch["source_idx"] = torch.tensor(
                [ex["source_idx"] for ex in examples],
                dtype=torch.long
            )
        return batch
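
# Note: with Trainer's default remove_unused_columns=True, "source_idx" is
# stripped from the dataset before batches reach this collator, so the branch
# above never fires. To keep the provenance ids, pass
# remove_unused_columns=False in TrainingArguments and pop "source_idx" from
# the inputs (e.g. in a custom compute_loss) before the model forward pass.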


def main():
    # Configuration
    source_mapper = SourceMapper()
    processor = LegalDataProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    data = []

    def process_file(file_path):
        text = processor.extract_text(file_path)
        if not text:
            return
        doc_type = processor.identify_document(os.path.basename(file_path))
        if doc_type != "Opracowanie własne":
            processed = processor.process_legal(text, doc_type)
        else:
            processed = processor.process_custom(text)
        for chunk, source in processed:  # renamed from "text" to avoid shadowing
            source_mapper.add_source(source)
            data.append({
                "text": chunk,
                "source_idx": source_mapper.get_idx(source)
            })

    # Multithreaded processing (list.append is atomic under the GIL)
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):
            for file in files:
                futures.append(executor.submit(
                    process_file,
                    os.path.join(root, file)
                ))
        for future in futures:
            try:
                future.result()
            except Exception as e:
                print(f"Error: {e}")

    # Augmentation: two extra sentence-shuffled variants per example
    print(f"Before augmentation: {len(data)} examples")
    augmented = []
    for item in data:
        for _ in range(2):
            sentences = item['text'].split('. ')
            random.shuffle(sentences)
            augmented.append({
                "text": '. '.join(sentences),
                "source_idx": item["source_idx"]
            })
    data += augmented
    print(f"After augmentation: {len(data)} examples")

    # Dataset preparation
    dataset = Dataset.from_list(data)

    def tokenize_fn(examples):
        # return plain lists (no return_tensors): datasets stores columns as
        # lists, and DataCollatorForLanguageModeling(mlm=False) builds the
        # tensors and derives labels from input_ids with padding masked to -100
        return tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True,
        )

    tokenized_ds = dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=32,
        num_proc=4,
        remove_columns=["text"]  # drop raw strings, which cannot be collated
    )

    # Model
    model = AutoModelForCausalLM.from_pretrained(
        "crumb/nano-mistral",
        trust_remote_code=True
    )
    # a no-op here (pad_token reuses eos_token, so the vocab did not grow), but harmless
    model.resize_token_embeddings(len(tokenizer))

    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch size of 16 per device
        learning_rate=1e-4,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
    )
    print("Starting training...")
    trainer.train()
    print("Training finished!")

    # Save the model and tokenizer
    model.save_pretrained("./trained_model")
    tokenizer.save_pretrained("./trained_model")


if __name__ == "__main__":
    main()
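
# Usage sketch (an assumption, not exercised by this script): reload the saved
# artifacts and generate a completion. "./trained_model" matches the
# save_pretrained paths above; the prompt is a made-up example.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("./trained_model")
#   lm = AutoModelForCausalLM.from_pretrained("./trained_model", trust_remote_code=True)
#   inputs = tok("[Kodeks cywilny] Art. 415", return_tensors="pt")
#   out = lm.generate(**inputs, max_new_tokens=64, do_sample=True, top_p=0.9)
#   print(tok.decode(out[0], skip_special_tokens=True))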