This commit is contained in:
l.gabrysiak 2025-02-25 22:17:13 +01:00
parent d073a1733b
commit c55fbe8632
1 changed file with 188 additions and 191 deletions

hft.py
import os
import torch
import torch.nn as nn
import random
import re
import json
import PyPDF2
import docx2txt
import pytesseract
import numpy as np
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
import nlpaug.augmenter.word as naw
from huggingface_hub import login
# Configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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")  # Replace with a real token
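# TORCH_USE_CUDA_DSA enables device-side assertions for debugging CUDA errors,
# and TOKENIZERS_PARALLELISM=false silences the fork warning the Rust
# tokenizers emit alongside multiprocessing. Rather than committing a token,
# an equivalent and safer pattern is to read it from the environment:
#   login(token=os.environ["HF_TOKEN"])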
class SourceMapper:
    """Maps source descriptions (document + article) to integer indices."""
    def __init__(self):
        # The body of this class was elided in the diff view; the two
        # dictionaries and the add_source/get_idx methods below are
        # reconstructed from how they are used later in the file.
        self.source_to_idx = {}
        self.idx_to_source = {}

    def add_source(self, source):
        if source not in self.source_to_idx:
            idx = len(self.source_to_idx)
            self.source_to_idx[source] = idx
            self.idx_to_source[idx] = source

    def get_idx(self, source):
        return self.source_to_idx.get(source, -1)

    def get_source(self, idx):
        return self.idx_to_source.get(idx, "Unknown")
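# Example round trip (hypothetical source string):
#   mapper = SourceMapper()
#   mapper.add_source("Kodeks cywilny, Art. 415")
#   mapper.get_source(mapper.get_idx("Kodeks cywilny, Art. 415"))
#   # -> "Kodeks cywilny, Art. 415"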
class LegalDataProcessor:
    def __init__(self, catalog_path):
        self.catalog = self.load_catalog(catalog_path)
        # Word-level antonym augmenter; instantiated here but not used in the
        # current pipeline (main() uses sentence shuffling instead).
        self.augmenter = naw.AntonymAug()

    def load_catalog(self, path):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return defaultdict(str)

    def identify_document(self, filename):
        base = os.path.splitext(filename)[0].lower()
        # "Opracowanie własne" (own compilation) is the fallback document type.
        return self.catalog.get(base, "Opracowanie własne")

    def extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                return self._extract_pdf(file_path)
            elif ext in ['.doc', '.docx']:
                return docx2txt.process(file_path)
            elif ext in ['.jpg', '.jpeg', '.png']:
                return self._extract_ocr(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return ""
    def _extract_pdf(self, path):
        text = ""
        with open(path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                # extract_text() can return None for image-only pages
                text += (page.extract_text() or "") + "\n"
        return re.sub(r'\s+', ' ', text)
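    # PyPDF2 development has moved to the maintained "pypdf" package; the same
    # reader code works there via `from pypdf import PdfReader`, which may be
    # worth switching to for bug fixes.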
    def _extract_ocr(self, path):
        return pytesseract.image_to_string(
            Image.open(path),
            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
        )

    def process_legal(self, text, doc_type):
        # Split on article/paragraph/chapter headers, keeping the delimiters
        # so each fragment can be re-attached to its header.
        articles = re.split(
            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
            text
        )
        processed = []
        current_header = ""
        for item in articles:
            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
                if current_header:
                    processed.append(current_header)
                current_header = item.strip()
            elif current_header:
                processed.append(current_header + " " + item.strip())
                current_header = ""
            else:
                processed.append(item.strip())
        return [
            (f"[{doc_type}] {p}", doc_type)
            for p in processed if len(p) > 30
        ]

    def process_custom(self, text):
        clean_text = re.sub(r'\s+', ' ', text).strip()
        chunk_size = 384
        overlap = 128
        chunks = [
            clean_text[i:i + chunk_size]
            for i in range(0, len(clean_text), chunk_size - overlap)
        ]
        return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]
class EnhancedDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        # Carry the custom source index through as a [batch_size] long tensor.
        if "source_idx" in examples[0]:
            batch["source_idx"] = torch.tensor(
                [ex["source_idx"] for ex in examples],
                dtype=torch.long
            )
        return batch
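# With mlm=False the parent DataCollatorForLanguageModeling copies input_ids
# into labels and sets positions holding the pad token to -100, so padding is
# ignored by the language-modelling loss.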
def main():
    # Configuration
    source_mapper = SourceMapper()
    processor = LegalDataProcessor("file_catalog.json")
    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
    tokenizer.pad_token = tokenizer.eos_token

    # Data preparation: filled in by the worker threads below, then checked
    # for emptiness once collection has finished.
    data = []
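    # CPython's GIL makes list.append atomic, so the worker threads spawned
    # below can share `data` without an explicit lock.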
    def process_file(file_path):
        text = processor.extract_text(file_path)
        if not text:
            return
        doc_type = processor.identify_document(os.path.basename(file_path))
        if doc_type != "Opracowanie własne":
            processed = processor.process_legal(text, doc_type)
        else:
            processed = processor.process_custom(text)
        for chunk, source in processed:
            source_mapper.add_source(source)
            data.append({
                "text": chunk,
                "source_idx": source_mapper.get_idx(source)
            })

    # Multithreaded processing
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = []
        for root, _, files in os.walk("files"):
            for file in files:
                futures.append(executor.submit(
                    process_file,
                    os.path.join(root, file)
                ))
        for future in futures:
            try:
                future.result()
            except Exception as e:
                print(f"Error: {str(e)}")

    if not data:
        print("\nNo data for training!")
        return
    # Augmentation: two extra sentence-shuffled variants per example
    print(f"Before augmentation: {len(data)} examples")
    augmented = []
    for item in data:
        for _ in range(2):
            sentences = item['text'].split('. ')
            random.shuffle(sentences)
            augmented.append({
                "text": '. '.join(sentences),
                "source_idx": item["source_idx"]
            })
    data += augmented
    print(f"After augmentation: {len(data)} examples")
    # Dataset preparation
    dataset = Dataset.from_list(data)

    def tokenize_fn(examples):
        tokenized = tokenizer(
            examples["text"],
            max_length=512,
            padding="max_length",
            truncation=True
        )
        return {
            "input_ids": tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
            # The collator rebuilds labels from input_ids (mlm=False), so
            # these act only as a fallback.
            "labels": tokenized["input_ids"],
            "source_idx": examples["source_idx"]
        }

    tokenized_ds = dataset.map(
        tokenize_fn,
        batched=True,
        batch_size=32,
        num_proc=4
    )
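    # padding="max_length" makes every row exactly 512 tokens, which is simple
    # but wasteful; dynamic padding (padding=True, letting the collator pad
    # per batch) would shrink short batches considerably.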
    # Model
    model = AutoModelForCausalLM.from_pretrained(
        "crumb/nano-mistral",
        trust_remote_code=True
    )
    model.resize_token_embeddings(len(tokenizer))
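    # Since pad_token reuses the existing eos token, the vocabulary did not
    # grow and resize_token_embeddings(len(tokenizer)) is effectively a no-op
    # kept for safety.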
    # Training
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        fp16=torch.cuda.is_available(),
        logging_steps=20,
        save_strategy="epoch",
        report_to="none"
    )
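    # Effective batch size: 2 sequences/device x 8 accumulation steps = 16
    # sequences per optimizer update. Note that Trainer's default
    # remove_unused_columns=True strips source_idx before batching, since the
    # plain causal-LM forward() does not accept it.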
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
    )

    print("Starting training...")
    trainer.train()
    print("Training finished!")

    # Save the model
    model.save_pretrained("./trained_model")
    tokenizer.save_pretrained("./trained_model")
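    # The saved artifacts can be reloaded later with:
    #   AutoModelForCausalLM.from_pretrained("./trained_model")
    #   AutoTokenizer.from_pretrained("./trained_model")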
if __name__ == "__main__":
main()