This commit is contained in:
l.gabrysiak 2025-02-25 22:17:13 +01:00
parent d073a1733b
commit c55fbe8632
1 changed file with 188 additions and 191 deletions

hft.py

@@ -1,21 +1,30 @@
import os
import torch
import random
import re
import json
import PyPDF2
import docx2txt
import pytesseract
import numpy as np
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
import nlpaug.augmenter.word as naw
from huggingface_hub import login

# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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")  # Replace with a real token
class SourceMapper:
    def __init__(self):
@@ -33,227 +42,215 @@ class SourceMapper:
    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
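# (The collapsed hunk above hides the rest of SourceMapper; from its usage
# below it is assumed to also expose add_source(source) and
# get_idx(source) -> int, backed by the idx_to_source mapping.)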
class LegalDataProcessor:
    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        self.augmenter = naw.AntonymAug()

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception:
            return defaultdict(str)

    def identify_document(self, filename):
        base = os.path.splitext(filename)[0].lower()
        return self.catalog.get(base, "Opracowanie własne")
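    # file_catalog.json is assumed to map lowercase base filenames to document
    # titles, e.g. {"kodeks_cywilny": "Kodeks cywilny"}; files missing from the
    # catalog fall back to "Opracowanie własne" and get generic chunking.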
    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self._extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self._extract_ocr(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""

    def _extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                # extract_text() may return None for image-only pages
                text += (page.extract_text() or "") + "\n"
        return re.sub(r'\s+', ' ', text)

    def _extract_ocr(self, path):
        return pytesseract.image_to_string(
            Image.open(path),
            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
        )
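    # --psm 4 treats the page as a single column of variably sized text and
    # --oem 3 selects the default (LSTM) engine; preserve_interword_spaces=1
    # keeps column spacing, which helps on scanned legal layouts.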
    def process_legal(self, text, doc_type):
        articles = re.split(
            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
            text
        )
        processed = []
        current_header = ""
        for item in articles:
            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
                if current_header:
                    processed.append(current_header)
                current_header = item.strip()
            elif current_header:
                processed.append(current_header + " " + item.strip())
                current_header = ""
            else:
                processed.append(item.strip())
        if current_header:
            # Keep a trailing header that has no body after it
            processed.append(current_header)
        return [
            (f"[{doc_type}] {p}", doc_type)
            for p in processed if len(p) > 30
        ]
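    # Example: "Wstęp Art. 5 Treść § 2 Dalej" splits into a preamble plus
    # header/body pairs, yielding ("[<doc_type>] Art. 5 Treść", doc_type) and
    # ("[<doc_type>] § 2 Dalej", doc_type); fragments of 30 characters or
    # fewer are dropped by the final filter.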
    def process_custom(self, text):
        clean_text = re.sub(r'\s+', ' ', text).strip()
        chunk_size = 384
        overlap = 128
        chunks = [
            clean_text[i:i + chunk_size]
            for i in range(0, len(clean_text), chunk_size - overlap)
        ]
        return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]
class EnhancedDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        if "source_idx" in examples[0]:
            batch["source_idx"] = torch.tensor(
                [ex["source_idx"] for ex in examples],
                dtype=torch.long
            )
        return batch
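# With mlm=False the parent collator already builds causal-LM labels: it
# clones input_ids and sets pad-token positions to -100 so padding is
# ignored by the loss, which makes the labels produced in tokenize_fn
# redundant but harmless.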
def main():
    source_mapper = SourceMapper()
    processor = LegalDataProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation
    data = []

    def process_file(file_path):
        text = processor.extract_text(file_path)
        if not text:
            return
        doc_type = processor.identify_document(os.path.basename(file_path))
        if doc_type != "Opracowanie własne":
            processed = processor.process_legal(text, doc_type)
        else:
            processed = processor.process_custom(text)
        for chunk_text, source in processed:
            source_mapper.add_source(source)
            data.append({
                "text": chunk_text,
                "source_idx": source_mapper.get_idx(source)
            })
    # Multithreaded processing
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):
            for file in files:
                futures.append(executor.submit(
                    process_file,
                    os.path.join(root, file)
                ))
        for future in futures:
            try:
                future.result()
            except Exception as e:
                print(f"Error: {str(e)}")
    # Augmentation
    print(f"Before augmentation: {len(data)} examples")
    augmented = []
    for item in data:
        for _ in range(2):  # 2 additional shuffled variants
            sentences = item['text'].split('. ')
            random.shuffle(sentences)
            augmented.append({
                "text": '. '.join(sentences),
                "source_idx": item["source_idx"]
            })
    data += augmented
    print(f"After augmentation: {len(data)} examples")
    # Dataset preparation
    dataset = Dataset.from_list(data)

    def tokenize_fn(examples):
        tokenized = tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True
        )
        # Return plain lists; map(batched=True) cannot consume tensors here
        tokenized["labels"] = [ids[:] for ids in tokenized["input_ids"]]
        tokenized["source_idx"] = examples["source_idx"]
        return tokenized
    tokenized_ds = dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=32,
        num_proc=4
    )
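    # batched=True hands tokenize_fn lists of up to 32 texts at a time;
    # num_proc=4 runs four worker processes, each with its own copy of the
    # tokenizer.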
    # Model
    model = AutoModelForCausalLM.from_pretrained(
        "crumb/nano-mistral",
        trust_remote_code=True
    )
    model.resize_token_embeddings(len(tokenizer))
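    # resize_token_embeddings is effectively a no-op here: pad_token reuses
    # the existing eos token, so the vocabulary size is unchanged.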
    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )
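    # Effective batch size: 2 per device x 8 accumulation steps = 16
    # sequences per optimizer update.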
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
    )
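    # Caveat: with the Trainer's default remove_unused_columns=True the
    # "source_idx" column is dropped before batching (the plain causal-LM
    # forward() does not accept it), so the collator's source_idx branch
    # never fires; passing it through would require a wrapper model such as
    # the previous revision's CustomModel.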
print("\nRozpoczęcie treningu...") print("Rozpoczęcie treningu...")
trainer.train() trainer.train()
print("Trening zakończony!")
# Zapisz model
model.save_pretrained("./trained_model")
tokenizer.save_pretrained("./trained_model")
if __name__ == "__main__":
    main()