# ably.do/hft.py

import os
import re
import json
import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import cpu_count

import torch
import numpy as np
import PyPDF2
import docx2txt
import pytesseract
from PIL import Image
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
import nlpaug.augmenter.word as naw  # nlpaug exposes augmenters via submodules, not top-level classes
from huggingface_hub import login

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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")  # Replace with a real token; avoid committing real credentials

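# Maps document-source labels to stable integer indices so each training example
# can carry provenance information alongside its text.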
class SourceMapper:
    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")

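# Extracts raw text from PDF/DOCX/image/plain-text files and splits it into
# training chunks, using the file catalog to decide between legal-structure
# chunking and generic sliding-window chunking.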
class LegalDataProcessor:
    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        self.augmenter = naw.AntonymAug()  # word-level augmenter (not used in the pipeline below)

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return defaultdict(str)

    def identify_document(self, filename):
        base = os.path.splitext(filename)[0].lower()
        # "Opracowanie własne" (own material) is the fallback for files missing from the catalog.
        return self.catalog.get(base, "Opracowanie własne")

    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self._extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self._extract_ocr(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""

    def _extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                # extract_text() can return None for pages without a text layer
                text += (page.extract_text() or "") + "\n"
        return re.sub(r'\s+', ' ', text)

def _extract_ocr(self, path):
return pytesseract.image_to_string(
Image.open(path),
config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
)
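
    # Split a legal text on article/section/chapter markers (Art., §, Rozdział)
    # and prefix each fragment with its document type.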
    def process_legal(self, text, doc_type):
        articles = re.split(
            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
            text
        )
        processed = []
        current_header = ""

        for item in articles:
            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
                if current_header:
                    processed.append(current_header)
                current_header = item.strip()
            elif current_header:
                processed.append(current_header + " " + item.strip())
                current_header = ""
            else:
                processed.append(item.strip())

        return [
            (f"[{doc_type}] {p}", doc_type)
            for p in processed if len(p) > 30
        ]

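    # Fallback chunking for non-catalogued documents: fixed-size windows with overlap.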
def process_custom(self, text):
clean_text = re.sub(r'\s+', ' ', text).strip()
chunk_size = 384
overlap = 128
chunks = [
clean_text[i:i+chunk_size]
for i in range(0, len(clean_text), chunk_size - overlap)
]
return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]
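
# Standard causal-LM collator extended to carry the per-example source index
# through to the batch.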
class EnhancedDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        if "source_idx" in examples[0]:
            batch["source_idx"] = torch.tensor(
                [ex["source_idx"] for ex in examples],
                dtype=torch.long
            )
        return batch

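# End-to-end pipeline: extract text from ./files, chunk and augment it,
# tokenize, then fine-tune crumb/nano-mistral.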
def main():
    # Configuration
    source_mapper = SourceMapper()
    processor = LegalDataProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    data = []

    def process_file(file_path):
        text = processor.extract_text(file_path)
        if not text:
            return
        doc_type = processor.identify_document(os.path.basename(file_path))
        if doc_type != "Opracowanie własne":
            processed = processor.process_legal(text, doc_type)
        else:
            processed = processor.process_custom(text)
        for chunk, source in processed:
            source_mapper.add_source(source)
            data.append({
                "text": chunk,
                "source_idx": source_mapper.get_idx(source)
            })

    # Multithreaded processing
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):
            for file in files:
                futures.append(executor.submit(
                    process_file,
                    os.path.join(root, file)
                ))
        for future in futures:
            try:
                future.result()
            except Exception as e:
                print(f"Error: {str(e)}")

    # Augmentation
    print(f"Before augmentation: {len(data)} examples")
    augmented = []
    for item in data:
        for _ in range(2):  # 2 extra shuffled variants per example
            sentences = item['text'].split('. ')
            random.shuffle(sentences)
            augmented.append({
                "text": '. '.join(sentences),
                "source_idx": item["source_idx"]
            })
    data += augmented
    print(f"After augmentation: {len(data)} examples")

    # Dataset preparation
dataset = Dataset.from_list(data)
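    # Tokenize each chunk to a fixed 512-token window; source_idx is passed through.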
    def tokenize_fn(examples):
        tokenized = tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True
        )
        # Causal LM fine-tuning: labels mirror the input ids.
        tokenized["labels"] = tokenized["input_ids"].copy()
        tokenized["source_idx"] = examples["source_idx"]
        return tokenized

tokenized_ds = dataset.map(
tokenize_fn,
batched=True,
batch_size=32,
num_proc=4
)
# Model
model = AutoModelForCausalLM.from_pretrained(
"crumb/nano-mistral",
trust_remote_code=True
)
model.resize_token_embeddings(len(tokenizer))
    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
    )

print("Rozpoczęcie treningu...")
2025-02-25 14:38:44 -05:00
trainer.train()
2025-02-25 16:17:13 -05:00
print("Trening zakończony!")
# Zapisz model
model.save_pretrained("./trained_model")
tokenizer.save_pretrained("./trained_model")
2025-02-25 14:38:44 -05:00
if __name__ == "__main__":
main()