# Listing metadata from the code viewer: Python, 285 lines, 9.2 KiB.
import os
|
|
import torch
|
|
import random
|
|
import re
|
|
import json
|
|
import PyPDF2
|
|
import docx2txt
|
|
import pytesseract
|
|
import numpy as np
|
|
from PIL import Image
|
|
from collections import defaultdict
|
|
from multiprocessing import cpu_count
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForCausalLM,
|
|
TrainingArguments,
|
|
Trainer,
|
|
DataCollatorForLanguageModeling
|
|
)
|
|
from datasets import Dataset
|
|
from nlpaug.augmenter.word import WordAugmenter
|
|
from huggingface_hub import login
|
|
|
|
# Configuration
# Disable tokenizer-internal parallelism; the script does its own threading
# and multi-process dataset mapping further down.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Authenticate against the Hugging Face Hub using a token taken from the
# environment. Credentials must never be committed to source control: the
# token that used to be hard-coded here has to be treated as leaked and
# revoked in the Hub account settings.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
|
|
|
|
class SourceMapper:
    """Bidirectional mapping between source names and dense integer indices.

    Indices are handed out in first-seen order starting at 0. Both directions
    of the mapping are always updated together, and allocation is guarded by
    a lock because the mapper is shared between worker threads.
    """

    def __init__(self):
        # Function-scope import keeps this class self-contained.
        import threading

        self.source_to_idx = {}
        self.idx_to_source = {}
        self._lock = threading.Lock()

    def _register(self, source):
        """Return the index of *source*, allocating a new one if needed."""
        with self._lock:
            idx = self.source_to_idx.get(source)
            if idx is None:
                idx = len(self.source_to_idx)
                self.source_to_idx[source] = idx
                self.idx_to_source[idx] = source
            return idx

    def add_source(self, source):
        """Register *source*; empty or ``None`` values are ignored."""
        if source:
            self._register(source)

    def get_idx(self, source):
        """Return the index of *source*, or -1 for an empty/``None`` source.

        A source that has not been seen before is registered on the fly, and
        the reverse mapping is updated too so ``get_source`` can resolve the
        returned index.
        """
        return self._register(source) if source else -1

    def get_source(self, idx):
        """Return the source name for *idx*, or ``"Unknown"`` if unmapped."""
        return self.idx_to_source.get(idx, "Unknown")
|
|
|
|
class LegalProcessor:
    """Extract text from documents and split it into tagged training chunks.

    Every chunk returned by the splitters starts with a ``[<doc_type>]`` tag
    (``[Custom]`` for files that are not listed in the catalog).
    """

    # Structural headers of Polish legal acts: "Art. 12a", "§ 3",
    # "Rozdział IV". The word boundaries stop the pattern from firing inside
    # ordinary words (e.g. "Start 5" or "rozdział dla").
    _LEGAL_HEADER = re.compile(
        r'(?i)(\bArt[\.\s]*\d+[a-z]*|§\s*\d+|\bRozdział\s+[IVXLCDM]+\b)'
    )

    def __init__(self, catalog_path):
        """Load the file catalog from *catalog_path* and build the augmenter."""
        self.catalog = self.load_catalog(catalog_path)
        self.augmenter = self.init_augmenter()

    def load_catalog(self, path):
        """Return the JSON catalog stored at *path* as a dict.

        The catalog is optional: a missing, unreadable or malformed file
        yields an empty dict, so every document is treated as "Custom".
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                catalog = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            print(f"Nie można wczytać katalogu {path}: {e}")
            return {}
        if not isinstance(catalog, dict):
            # identify_doc_type() relies on dict-style .get() lookups.
            print(f"Nieprawidłowy format katalogu {path}: oczekiwano obiektu JSON")
            return {}
        return catalog

    def init_augmenter(self):
        """Create the WordNet synonym augmenter (at most 3 words per text)."""
        # SynonymAug is a sibling class of WordAugmenter inside
        # nlpaug.augmenter.word, not an attribute of it.
        from nlpaug.augmenter.word import SynonymAug

        return SynonymAug(aug_src='wordnet', aug_max=3)

    def process_file(self, file_path):
        """Return the list of tagged chunks for *file_path* ([] if no text)."""
        text = self.extract_text(file_path)
        if not text:
            return []
        doc_type = self.identify_doc_type(file_path)
        return self.split_content(text, doc_type)

    def extract_text(self, file_path):
        """Return the text of *file_path*, chosen by extension.

        PDFs, Word documents and images go through their dedicated
        extractors; anything else is read as UTF-8 text. Any failure is
        reported on stdout and results in an empty string.
        """
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self.extract_pdf(file_path)
            if ext in ('.doc', '.docx'):
                return docx2txt.process(file_path)
            if ext in ('.jpg', '.jpeg', '.png'):
                return self.extract_image(file_path)
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            print(f"Błąd przetwarzania {file_path}: {str(e)}")
            return ""

    def extract_pdf(self, path):
        """Return the text of all pages of the PDF with whitespace collapsed."""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            # extract_text() can yield None for pages without a text layer.
            pages = [page.extract_text() or "" for page in reader.pages]
        text = "".join(f"{page_text}\n" for page_text in pages)
        return re.sub(r'\s+', ' ', text)

    def extract_image(self, path):
        """Return the text recognised by Tesseract in the image at *path*."""
        # The context manager releases the file handle once OCR has finished.
        with Image.open(path) as image:
            return pytesseract.image_to_string(
                image,
                config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
            )

    def identify_doc_type(self, file_path):
        """Return the catalog entry for the file, or ``"Custom"``.

        The lookup key is the lower-cased file name without its extension.
        """
        base = os.path.splitext(os.path.basename(file_path))[0].lower()
        return self.catalog.get(base, "Custom")

    def split_content(self, text, doc_type):
        """Split *text* with the strategy that matches *doc_type*."""
        if doc_type == "Custom":
            return self.split_custom(text)
        return self.split_legal(text, doc_type)

    def split_legal(self, text, doc_type):
        """Split a legal act into ``"[doc_type] <header>: <body>"`` chunks.

        Text that precedes the first header is kept and tagged with
        ``[doc_type]`` as well, so every chunk carries a source tag. Chunks
        of 50 characters or fewer are dropped.
        """
        tag = f"[{doc_type}]"
        # With a single capturing group, re.split() alternates
        # body, header, body, ... so odd positions are always headers.
        parts = self._LEGAL_HEADER.split(text)
        results = []
        pending_header = ""

        for position, part in enumerate(parts):
            if not part:
                continue
            if position % 2 == 1:
                if pending_header:
                    # Two headers in a row: keep the first one on its own.
                    results.append(pending_header)
                pending_header = f"{tag} {part.strip()}"
            elif pending_header:
                results.append(f"{pending_header}: {part.strip()}")
                pending_header = ""
            else:
                results.append(f"{tag} {part.strip()}")

        return [chunk for chunk in results if len(chunk) > 50]

    def split_custom(self, text, chunk_size=384, overlap=64):
        """Split free-form text into overlapping, ``[Custom]``-tagged chunks.

        Args:
            text: Raw text; runs of whitespace are collapsed first.
            chunk_size: Maximum number of characters per chunk.
            overlap: Number of characters shared by consecutive chunks.

        Raises:
            ValueError: If *chunk_size* is not positive or *overlap* is not
                in ``[0, chunk_size)`` (the window would never advance).
        """
        if chunk_size <= 0 or not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        clean_text = re.sub(r'\s+', ' ', text).strip()
        step = chunk_size - overlap

        chunks = []
        for start in range(0, len(clean_text), step):
            end = start + chunk_size
            chunks.append(clean_text[start:end])
            if end >= len(clean_text):
                # The text is exhausted; another window would only repeat
                # the overlap that this chunk already contains.
                break

        return [f"[Custom] {chunk}" for chunk in chunks if chunk.strip()]
|
|
|
|
class CustomModel(torch.nn.Module):
    """Causal language model with an additive, learned per-source embedding.

    The pretrained model is frozen except for its last two decoder blocks
    and its output embeddings; the source embedding table is trained from
    scratch.
    """

    # Attribute paths under which common architectures keep their decoder
    # blocks (GPT-2 style, Llama/Mistral style, OPT style, GPT-NeoX style).
    _LAYER_PATHS = (
        "transformer.h",
        "model.layers",
        "model.decoder.layers",
        "gpt_neox.layers",
    )

    def __init__(self, model_name, num_sources=1000):
        """Load *model_name* and create a table of *num_sources* embeddings."""
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.source_emb = torch.nn.Embedding(
            num_sources, self.base_model.config.hidden_size
        )

        # Freeze the pretrained parameters
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Unfreeze the last decoder blocks
        for layer in self._decoder_layers(self.base_model)[-2:]:
            layer.requires_grad_(True)

        output_embeddings = self.base_model.get_output_embeddings()
        if output_embeddings is not None:
            output_embeddings.requires_grad_(True)

    @staticmethod
    def _decoder_layers(base_model):
        """Return the sequence of decoder blocks of *base_model*.

        Raises:
            AttributeError: If none of the known attribute paths exists.
        """
        for path in CustomModel._LAYER_PATHS:
            node = base_model
            for attr in path.split("."):
                node = getattr(node, attr, None)
                if node is None:
                    break
            else:
                return node
        raise AttributeError(
            f"Cannot locate decoder layers on {type(base_model).__name__}; "
            f"tried: {', '.join(CustomModel._LAYER_PATHS)}"
        )

    def forward(self, input_ids, attention_mask=None, labels=None, source_idx=None):
        """Run the base model on token embeddings shifted by the source vector.

        Args:
            input_ids: Token ids, looked up in the base model's input
                embeddings.
            attention_mask: Passed through to the base model.
            labels: Passed through to the base model.
            source_idx: Source indices; assumed to hold one index per
                sequence (shape ``(batch,)``) - TODO confirm with callers.
                When omitted, no source vector is added.

        Returns:
            Whatever the base model returns for these inputs.
        """
        inputs_embeds = self.base_model.get_input_embeddings()(input_ids)

        if source_idx is not None:
            # Keep indices inside the table. NOTE(review): -1 is clamped to
            # row 0 and therefore shares it with the first real source.
            last_row = self.source_emb.num_embeddings - 1
            source_vec = self.source_emb(source_idx.clamp(0, last_row)).unsqueeze(1)
            # Out-of-place add: the embedding layer's output is not mutated.
            inputs_embeds = inputs_embeds + source_vec

        return self.base_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )

    def save_pretrained(self, save_directory):
        """Save the base model and the source embedding table.

        The base model is written with its own ``save_pretrained``; the
        source embeddings go to ``source_emb.pt`` in the same directory.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.base_model.save_pretrained(save_directory)
        torch.save(
            self.source_emb.state_dict(),
            os.path.join(save_directory, "source_emb.pt"),
        )
|
|
|
|
def main():
    """Build the corpus from ``files/``, fine-tune the model and save it.

    Steps: extract and chunk every file under ``files/``, add two
    word-shuffled variants per chunk, tokenize, train ``CustomModel`` with
    the Hugging Face ``Trainer`` and write the result to
    ``./trained_legal_model``.
    """
    model_name = "crumb/nano-mistral"
    output_dir = "./trained_legal_model"

    # Component initialisation
    source_mapper = SourceMapper()
    processor = LegalProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Data processing
    data = []
    tag_pattern = re.compile(r"\[([^\]]+)\]")

    def process_and_augment(file_path):
        """Append the chunks of *file_path* and their augmentations to data."""
        try:
            for text in processor.process_file(file_path):
                # Chunks normally start with "[<source>]"; text without a tag
                # gets an empty source, which maps to index -1.
                tag = tag_pattern.match(text)
                source = tag.group(1) if tag else ""
                source_mapper.add_source(source)
                source_idx = source_mapper.get_idx(source)

                # Original text
                data.append({"text": text, "source_idx": source_idx})

                # Augmentation - 2 variants with randomly reordered words
                words = text.split()
                if len(words) > 5:
                    for _ in range(2):
                        shuffled = random.sample(words, len(words))
                        data.append({
                            "text": " ".join(shuffled),
                            "source_idx": source_idx
                        })
        except Exception as e:
            # Best effort: a single bad file must not abort the whole run.
            print(f"Błąd przetwarzania {file_path}: {str(e)}")

    # Multi-threaded processing of every file below the data folder
    file_paths = [
        os.path.join(root, name)
        for root, _, files in os.walk("files")
        for name in files
    ]
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        # list() drains the iterator so that every task has finished.
        list(executor.map(process_and_augment, file_paths))

    print(f"\nPrzygotowano {len(data)} przykładów treningowych")
    if not data:
        # Nothing to train on; building a dataset/trainer would only crash.
        print("Brak danych treningowych - trening pominięty.")
        return

    print("Przykładowe dane:")
    # Never request more samples than there are examples.
    for example in random.sample(data, min(3, len(data))):
        print(f"\nŹródło: {source_mapper.get_source(example['source_idx'])}")
        print(f"Tekst: {example['text'][:150]}...")

    # Dataset preparation
    dataset = Dataset.from_list(data)

    def tokenize_fn(examples):
        """Tokenize one batch; plain lists keep the batch axis intact."""
        tokenized = tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True
        )
        return {
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
            "labels": [list(ids) for ids in tokenized["input_ids"]],
            "source_idx": examples["source_idx"]
        }

    tokenized_ds = dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=32,
        num_proc=4,
        remove_columns=["text"]
    )

    # Model initialisation
    model = CustomModel(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training configuration
    training_args = TrainingArguments(
        output_dir="./wyniki",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    )

    # Training
    print("\nRozpoczynanie treningu...")
    trainer.train()

    # Saving the model. CustomModel is a plain nn.Module, so the pretrained
    # part is saved through the wrapped model and the extra source embedding
    # table is stored next to it.
    os.makedirs(output_dir, exist_ok=True)
    model.base_model.save_pretrained(output_dir)
    torch.save(
        model.source_emb.state_dict(),
        os.path.join(output_dir, "source_emb.pt")
    )
    tokenizer.save_pretrained(output_dir)
    # The index -> source table is needed to interpret source_idx later on.
    with open(os.path.join(output_dir, "source_map.json"), "w", encoding="utf-8") as f:
        json.dump(source_mapper.idx_to_source, f, ensure_ascii=False, indent=2)
    print("Trening zakończony pomyślnie!")


if __name__ == "__main__":
    main()