# ably.do/hft.py

import os
import torch
import random
import re
import json
import PyPDF2
import docx2txt
import pytesseract
import numpy as np
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    default_data_collator
)
from datasets import Dataset
import nlpaug.augmenter.word as naw
from huggingface_hub import login

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="YOUR_HF_TOKEN")  # Replace with your Hugging Face token
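# NOTE: a Hub token is only needed for gated/private models or for pushing to
# the Hub; "YOUR_HF_TOKEN" is a placeholder, not a real credential.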


class SourceMapper:
    def __init__(self):
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source and source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx[source] if source else -1

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
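
# Illustrative SourceMapper usage (the source name is a made-up example;
# the defaultdict assigns consecutive indices in insertion order):
#   mapper = SourceMapper()
#   mapper.add_source("Kodeks cywilny")  # first source is assigned index 0
#   mapper.get_idx("Kodeks cywilny")     # -> 0
#   mapper.get_source(0)                 # -> "Kodeks cywilny"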


class LegalProcessor:
    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        self.augmenter = self.init_augmenter()

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            # Missing or malformed catalog: fall back to an empty mapping
            return defaultdict(str)
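
    # Expected file_catalog.json format (hypothetical example): a JSON object
    # mapping lowercase file basenames to document types, e.g.
    #   {"kodeks_cywilny": "Kodeks cywilny", "meeting_notes": "Custom"}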

    def init_augmenter(self):
        # WordNet-based synonym augmenter (note: initialized but unused below,
        # where simple word shuffling is applied instead)
        return naw.SynonymAug(aug_src='wordnet', aug_max=3)

    def process_file(self, file_path):
        text = self.extract_text(file_path)
        if not text:
            return []
        doc_type = self.identify_doc_type(file_path)
        return self.split_content(text, doc_type)

    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self.extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                # docx2txt only parses .docx; legacy .doc files will raise
                # and fall through to the except below
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self.extract_image(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""

    def extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                # extract_text() can come back empty for scanned pages
                text += (page.extract_text() or "") + "\n"
        return re.sub(r'\s+', ' ', text)

    def extract_image(self, path):
        # --psm 4: assume a single column of variably sized text;
        # --oem 3: default OCR engine selection
        return pytesseract.image_to_string(
            Image.open(path),
            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
        )

    def identify_doc_type(self, file_path):
        base = os.path.splitext(os.path.basename(file_path))[0].lower()
        return self.catalog.get(base, "Custom")

    def split_content(self, text, doc_type):
        if doc_type == "Custom":
            return self.split_custom(text)
        return self.split_legal(text, doc_type)

    def split_legal(self, text, doc_type):
        # Headings of Polish legal texts: "Art. 5", "§ 2", "Rozdział IV" (chapter)
        pattern = r'(?i)(Art[\.\s]*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)'
        parts = re.split(pattern, text)
        results = []
        current_header = ""
        for part in parts:
            if not part:
                continue
            if re.match(pattern, part):
                # New heading: flush any dangling header, then start a fresh one
                if current_header:
                    results.append(current_header)
                current_header = f"[{doc_type}] {part.strip()}"
            else:
                if current_header:
                    results.append(f"{current_header}: {part.strip()}")
                    current_header = ""
                else:
                    results.append(part.strip())
        # Drop fragments too short to be meaningful training examples
        return [chunk for chunk in results if len(chunk) > 50]
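
    # Illustrative (hypothetical input): "Art. 1 <provision text> Art. 2 <provision text>"
    # with doc_type "Kodeks" yields entries like "[Kodeks] Art. 1: <provision text>",
    # provided each entry exceeds the 50-character filter above.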

    def split_custom(self, text):
        clean_text = re.sub(r'\s+', ' ', text).strip()
        chunk_size = 384
        overlap = 64
        chunks = []
        start = 0
        while start < len(clean_text):
            end = start + chunk_size
            chunks.append(clean_text[start:end])
            start = end - overlap
        return [f"[Custom] {chunk}" for chunk in chunks if chunk.strip()]


class CustomModel(torch.nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
        # Learned per-source embedding added to the token embeddings
        self.source_emb = torch.nn.Embedding(1000, self.base_model.config.hidden_size)
        # Freeze the base parameters
        for param in self.base_model.parameters():
            param.requires_grad = False
        # Unfreeze the last two decoder blocks
        # (Mistral-style models expose them as model.layers; GPT-2-style would be transformer.h)
        for layer in self.base_model.model.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True
        self.base_model.get_output_embeddings().requires_grad_(True)

    def forward(self, input_ids, attention_mask, labels, source_idx):
        inputs_embeds = self.base_model.get_input_embeddings()(input_ids)
        # Add one source embedding per sequence, clamped to the table size
        source_emb = self.source_emb(source_idx.clamp(0, 999)).unsqueeze(1)
        inputs_embeds = inputs_embeds + source_emb
        return self.base_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
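
# Shape sketch for CustomModel.forward (batch B, sequence length T, hidden size H):
#   input_ids (B, T) -> inputs_embeds (B, T, H)
#   source_idx (B,) -> source_emb (B, 1, H), broadcast across all T positions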


def main():
    # Initialize components
    source_mapper = SourceMapper()
    processor = LegalProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data processing
    data = []

    def process_and_augment(file_path):
        try:
            items = processor.process_file(file_path)
            for text in items:
                # Recover the source tag from the "[Source] ..." prefix
                source = text.split("]")[0][1:]
                source_mapper.add_source(source)
                # Original text
                data.append({
                    "text": text,
                    "source_idx": source_mapper.get_idx(source)
                })
                # Augmentation: 2 variants per example
                for _ in range(2):
                    words = text.split()
                    if len(words) > 5:
                        # Randomly shuffle word order
                        random.shuffle(words)
                        augmented = " ".join(words)
                        data.append({
                            "text": augmented,
                            "source_idx": source_mapper.get_idx(source)
                        })
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")

    # Process files in parallel threads
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):  # Data folder
            for file in files:
                file_path = os.path.join(root, file)
                futures.append(executor.submit(process_and_augment, file_path))
        for future in futures:
            future.result()

    print(f"\nPrepared {len(data)} training examples")
    print("Sample data:")
    for example in random.sample(data, min(3, len(data))):
        print(f"\nSource: {source_mapper.get_source(example['source_idx'])}")
        print(f"Text: {example['text'][:150]}...")

    # Prepare the dataset
    dataset = Dataset.from_list(data)

    def tokenize_fn(examples):
        tokenized = tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True
        )
        # Causal LM: labels are the input ids, with padding masked out (-100)
        labels = [
            [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
            for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return {
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
            "labels": labels,
            "source_idx": examples["source_idx"]
        }

    tokenized_ds = dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=32,
        num_proc=4,
        remove_columns=["text"]
    )
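    # At this point every row carries fixed-length input_ids/attention_mask/labels
    # plus a scalar source_idx, so no dynamic padding is needed at collation time.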

    # Initialize the model
    model = CustomModel("crumb/nano-mistral")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training configuration
    training_args = TrainingArguments(
        output_dir="./wyniki",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        # Features are already padded and labelled in tokenize_fn, so the
        # default collator (which also stacks source_idx) is sufficient
        data_collator=default_data_collator
    )

    # Training
    print("\nStarting training...")
    trainer.train()

    # Save the model: the wrapper is a plain nn.Module, so save the base model
    # in HF format and the source embedding weights separately
    model.base_model.save_pretrained("./trained_legal_model")
    torch.save(model.source_emb.state_dict(), "./trained_legal_model/source_emb.pt")
    tokenizer.save_pretrained("./trained_legal_model")
    print("Training completed successfully!")


if __name__ == "__main__":
    main()