This commit is contained in:
l.gabrysiak 2025-02-25 22:17:13 +01:00
parent d073a1733b
commit c55fbe8632
1 changed file with 188 additions and 191 deletions

hft.py

@@ -1,21 +1,30 @@
 import os
 import torch
-import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
-from datasets import Dataset
+import random
 import re
 import json
 import PyPDF2
 import docx2txt
 import pytesseract
-import numpy as np
 from PIL import Image
 from collections import defaultdict
+from multiprocessing import cpu_count
+from concurrent.futures import ThreadPoolExecutor
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling
+)
+from datasets import Dataset
+import nlpaug.augmenter.word as naw
 from huggingface_hub import login
 # Configuration
-os.environ['TORCH_USE_CUDA_DSA'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="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")
+login(token="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")  # Replace with a real token
 class SourceMapper:
     def __init__(self):
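Note on the new nlpaug dependency: nlpaug exposes its word-level augmenters through the nlpaug.augmenter.word module rather than a top-level WordAugmenter class, which is what the import above relies on. A minimal sketch of the assumed API (AntonymAug pulls antonyms from English WordNet, so results on Polish legal text will be rough; it also needs the NLTK corpora nlpaug fetches on first use):

    import nlpaug.augmenter.word as naw

    aug = naw.AntonymAug()  # word-level augmenter backed by WordNet antonyms
    print(aug.augment("The court may grant the motion"))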
@@ -33,227 +42,215 @@ class SourceMapper:
     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
-def load_file_catalog(catalog_path):
-    try:
-        with open(catalog_path, 'r', encoding='utf-8') as file:
-            return json.load(file)
+class LegalDataProcessor:
+    def __init__(self, catalog_path):
+        self.catalog = self.load_catalog(catalog_path)
+        self.augmenter = naw.AntonymAug()
-    except Exception as e:
-        print(f"Error loading the file catalog: {str(e)}")
-        return {}
-def identify_legal_document(filename, file_catalog):
-    base_name = os.path.splitext(filename)[0].lower()
-    return file_catalog.get(base_name, "Opracowanie własne")
-def extract_text_from_file(file_path):
-    try:
-        _, ext = os.path.splitext(file_path)
-        ext = ext.lower()
-        if ext in ['.txt', '.md']:
-            with open(file_path, 'r', encoding='utf-8') as file:
-                return file.read()
+    def load_catalog(self, path):
+        try:
+            with open(path, 'r', encoding='utf-8') as f:
+                return json.load(f)
+        except Exception:
+            return defaultdict(str)
-        elif ext == '.pdf':
-            text = ""
-            try:
-                with open(file_path, 'rb') as file:
-                    reader = PyPDF2.PdfReader(file)
-                    for page in reader.pages:
-                        text += page.extract_text() or ""
+    def identify_document(self, filename):
+        base = os.path.splitext(filename)[0].lower()
+        return self.catalog.get(base, "Opracowanie własne")
-            except Exception as e:
-                print(f"PDF error: {str(e)}")
-            return text
-        elif ext in ['.doc', '.docx']:
-            return docx2txt.process(file_path)
-        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-            return pytesseract.image_to_string(Image.open(file_path))
-        else:
-            print(f"Unsupported file format: {ext}")
+    def extract_text(self, file_path):
+        ext = os.path.splitext(file_path)[1].lower()
+        try:
+            if ext == '.pdf':
+                return self._extract_pdf(file_path)
+            elif ext in ['.doc', '.docx']:
+                return docx2txt.process(file_path)
+            elif ext in ['.jpg', '.jpeg', '.png']:
+                return self._extract_ocr(file_path)
+            else:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    return f.read()
+        except Exception as e:
+            print(f"Error processing {file_path}: {str(e)}")
             return ""
-    except Exception as e:
-        print(f"Text extraction error: {str(e)}")
-        return ""
-def prepare_dataset(directory, catalog_path, source_mapper):
-    file_catalog = load_file_catalog(catalog_path)
-    data = []
-    print(f"\n{'='*50}\nDATA DIAGNOSTICS\n{'='*50}")
+    def _extract_pdf(self, path):
+        text = ""
+        with open(path, 'rb') as f:
+            reader = PyPDF2.PdfReader(f)
+            for page in reader.pages:
+                text += page.extract_text() + "\n"
+        return re.sub(r'\s+', ' ', text)
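For reference, identify_document above treats the catalog as a flat map from lowercased base filename to document title, with "Opracowanie własne" ("own material") as the fallback for anything uncatalogued. A hypothetical file_catalog.json in the assumed shape (entries invented for illustration):

    {
        "kodeks_cywilny": "Kodeks cywilny",
        "kodeks_karny": "Kodeks karny"
    }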
-    for root, _, files in os.walk(directory):
-        for file in files:
-            file_path = os.path.join(root, file)
-            print(f"\nProcessing file: {file_path}")
-            try:
-                text = extract_text_from_file(file_path)
-                if not text.strip():
-                    print("Skipped - no text")
-                    continue
-                print(f"Text length: {len(text)} characters")
-                doc_type = identify_legal_document(file, file_catalog)
-                print(f"Detected document type: {doc_type}")
-                if doc_type != "Opracowanie własne":
-                    articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
-                    articles = [a.strip() for a in articles if a.strip()]
-                    print(f"Found {len(articles)} fragments")
-                    for i in range(0, len(articles)-1, 2):
-                        article_number = articles[i]
-                        article_content = articles[i+1]
-                        if len(article_content) < 50:
-                            continue
-                        source = f"{doc_type}, {article_number}"
-                        source_mapper.add_source(source)
-                        data.append({
-                            "text": f"{article_number} {article_content}",
-                            "source_idx": source_mapper.get_idx(source)
-                        })
-                else:
-                    clean_text = re.sub(r'\s+', ' ', text).strip()
-                    chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
-                    chunks = [c for c in chunks if c.strip()]
-                    for chunk in chunks:
-                        data.append({
-                            "text": chunk,
-                            "source_idx": -1
-                        })
-                    print(f"Added {len(chunks)} chunks")
-            except Exception as e:
-                print(f"Error while processing the file: {str(e)}")
-                continue
-    print(f"\nData preparation summary:")
-    print(f"Total number of examples: {len(data)}")
-    if data:
-        print("Sample entry:")
-        print(json.dumps(data[0], indent=2, ensure_ascii=False))
-    else:
-        print("NO DATA - check the diagnostics above")
-    return data
-class CustomModel(nn.Module):
-    def __init__(self, model_name, config):
-        super().__init__()
-        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-        self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
-        for param in self.base_model.parameters():
-            param.requires_grad = False
-        for param in self.base_model.get_output_embeddings().parameters():
-            param.requires_grad = True
-    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        if source_idx is not None:
-            valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
-            source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
-            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
-            return self.base_model(
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                labels=labels,
-                **kwargs
-            )
-        return self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
+    def _extract_ocr(self, path):
+        return pytesseract.image_to_string(
+            Image.open(path),
+            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
         )
-    def generate(self, *args, **kwargs):
-        return self.base_model.generate(*args, **kwargs)
+    def process_legal(self, text, doc_type):
+        articles = re.split(
+            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
+            text
+        )
+        processed = []
+        current_header = ""
+        for item in articles:
+            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
+                if current_header:
+                    processed.append(current_header)
+                current_header = item.strip()
+            elif current_header:
+                processed.append(current_header + " " + item.strip())
+                current_header = ""
+            else:
+                processed.append(item.strip())
+        return [
+            (f"[{doc_type}] {p}", doc_type)
+            for p in processed if len(p) > 30
+        ]
+    def process_custom(self, text):
+        clean_text = re.sub(r'\s+', ' ', text).strip()
+        chunk_size = 384
+        overlap = 128
+        chunks = [
+            clean_text[i:i+chunk_size]
+            for i in range(0, len(clean_text), chunk_size - overlap)
+        ]
+        return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]
-class CustomDataCollator(DataCollatorForLanguageModeling):
+class EnhancedDataCollator(DataCollatorForLanguageModeling):
     def torch_call(self, examples):
-        # Process the basic fields
-        input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
-        attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
-        labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
-        batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels
-        }
-        # Add source_idx if present
+        batch = super().torch_call(examples)
         if "source_idx" in examples[0]:
-            source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
-            batch["source_idx"] = source_idx
+            batch["source_idx"] = torch.tensor(
+                [ex["source_idx"] for ex in examples],
+                dtype=torch.long
+            )
         return batch
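Because EnhancedDataCollator defers to DataCollatorForLanguageModeling with mlm=False, the parent call pads the batch and builds causal-LM labels from input_ids, and the override only collects source_idx into a LongTensor. A toy invocation, assuming tokenizer is the eos-padded tokenizer configured in main below:

    collator = EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
    batch = collator([
        {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "source_idx": 0},
        {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 1], "source_idx": 2},
    ])
    print(batch["source_idx"])  # tensor([0, 2])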
 def main():
-    # Configuration
     source_mapper = SourceMapper()
-    model_name = "crumb/nano-mistral"
+    processor = LegalDataProcessor("file_catalog.json")
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
     tokenizer.pad_token = tokenizer.eos_token
-    # Data preparation
-    catalog_path = "file_catalog.json"
-    data = prepare_dataset("files", catalog_path, source_mapper)
-    if not data:
-        print("\nNo data for training!")
-        return
+    # Data preparation
+    data = []
+    def process_file(file_path):
+        nonlocal data
+        text = processor.extract_text(file_path)
+        if not text:
+            return
+        doc_type = processor.identify_document(os.path.basename(file_path))
+        if doc_type != "Opracowanie własne":
+            processed = processor.process_legal(text, doc_type)
+        else:
+            processed = processor.process_custom(text)
+        for text, source in processed:
+            source_mapper.add_source(source)
+            data.append({
+                "text": text,
+                "source_idx": source_mapper.get_idx(source)
+            })
+    # Multithreaded processing
+    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
+        futures = []
+        for root, _, files in os.walk("files"):
+            for file in files:
+                futures.append(executor.submit(
+                    process_file,
+                    os.path.join(root, file)
+                ))
+        for future in futures:
+            try:
+                future.result()
+            except Exception as e:
+                print(f"Error: {str(e)}")
+    # Augmentation
+    print(f"Before augmentation: {len(data)} examples")
+    augmented = []
+    for item in data:
+        for _ in range(2):  # 2 extra variants
+            sentences = item['text'].split('. ')
+            random.shuffle(sentences)
+            augmented.append({
+                "text": '. '.join(sentences),
+                "source_idx": item["source_idx"]
+            })
+    data += augmented
+    print(f"After augmentation: {len(data)} examples")
+    # Dataset preparation
     dataset = Dataset.from_list(data)
-    def tokenize_function(examples):
+    def tokenize_fn(examples):
         tokenized = tokenizer(
             examples["text"],
-            truncation=True,
-            padding="max_length",
             max_length=512,
+            padding="max_length",
+            truncation=True,
             return_tensors="pt"
         )
         return {
             "input_ids": tokenized["input_ids"].squeeze(),
             "attention_mask": tokenized["attention_mask"].squeeze(),
-            "labels": tokenized["input_ids"].squeeze().clone(),
+            "labels": tokenized["input_ids"].squeeze(),
-            "source_idx": examples["source_idx"]  # Added without converting to a tensor
+            "source_idx": examples["source_idx"]
         }
-    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
-    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
-    model.source_mapper = source_mapper
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
+    tokenized_ds = dataset.map(
+        tokenize_fn,
+        batched=True,
+        batch_size=32,
+        num_proc=4
+    )
+    # Model
+    model = AutoModelForCausalLM.from_pretrained(
+        "crumb/nano-mistral",
+        trust_remote_code=True
+    )
+    model.resize_token_embeddings(len(tokenizer))
+    # Training
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=3,
+        num_train_epochs=5,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
+        gradient_accumulation_steps=8,
-        learning_rate=2e-5,
+        learning_rate=1e-4,
         fp16=torch.cuda.is_available(),
-        logging_steps=10,
+        logging_steps=20,
-        save_strategy="steps",
-        save_steps=1000,
-        report_to="none",
-        remove_unused_columns=False
+        save_strategy="epoch",
+        report_to="none"
     )
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_dataset,
+        train_dataset=tokenized_ds,
-        data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
+        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
     )
-    print("\nStarting training...")
+    print("Starting training...")
     trainer.train()
-    print("Training finished!")
-    # Save the model
-    model.save_pretrained("./trained_model")
-    tokenizer.save_pretrained("./trained_model")
 if __name__ == "__main__":
     main()
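With save_strategy="epoch" and the explicit save_pretrained call gone, checkpoints land under ./results. A sketch of reloading one for generation (the checkpoint name is hypothetical; the actual directory depends on the global step at each epoch boundary):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    lm = AutoModelForCausalLM.from_pretrained("./results/checkpoint-500", trust_remote_code=True)
    out = lm.generate(**tok("Art. 1", return_tensors="pt"), max_new_tokens=40)
    print(tok.decode(out[0], skip_special_tokens=True))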