TEN KOD DZIAŁA

This commit is contained in:
l.gabrysiak 2025-02-25 22:54:44 +01:00
parent 999eded568
commit 9004cd8cc1
1 changed files with 194 additions and 218 deletions

346
hft.py
View File

@ -1,35 +1,21 @@
import nltk
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('punkt', quiet=True)
import os import os
import torch import torch
import random import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import Dataset
import re import re
import json import json
import PyPDF2 import PyPDF2
import docx2txt import docx2txt
import pytesseract import pytesseract
import numpy as np
from PIL import Image from PIL import Image
from collections import defaultdict from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import Dataset
from nlpaug.augmenter.word import SynonymAug
from huggingface_hub import login from huggingface_hub import login
# Konfiguracja # Konfiguracja
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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") # Zastąp swoim tokenem login(token="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")
class SourceMapper: class SourceMapper:
def __init__(self): def __init__(self):
@ -47,237 +33,227 @@ class SourceMapper:
def get_source(self, idx): def get_source(self, idx):
return self.idx_to_source.get(idx, "Unknown") return self.idx_to_source.get(idx, "Unknown")
class LegalProcessor: def load_file_catalog(catalog_path):
def __init__(self, catalog_path):
self.catalog = self.load_catalog(catalog_path)
self.augmenter = SynonymAug(aug_src='wordnet', aug_max=3, lang='pol')
def load_catalog(self, path):
try: try:
with open(path, 'r', encoding='utf-8') as f: with open(catalog_path, 'r', encoding='utf-8') as file:
return json.load(f) return json.load(file)
except: except Exception as e:
return defaultdict(str) print(f"Błąd wczytywania katalogu plików: {str(e)}")
return {}
def process_file(self, file_path): def identify_legal_document(filename, file_catalog):
text = self.extract_text(file_path) base_name = os.path.splitext(filename)[0].lower()
if not text: return file_catalog.get(base_name, "Opracowanie własne")
return []
doc_type = self.identify_doc_type(file_path) def extract_text_from_file(file_path):
return self.split_content(text, doc_type)
def extract_text(self, file_path):
ext = os.path.splitext(file_path)[1].lower()
try: try:
if ext == '.pdf': _, ext = os.path.splitext(file_path)
return self.extract_pdf(file_path) ext = ext.lower()
if ext in ['.txt', '.md']:
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
elif ext == '.pdf':
text = ""
try:
with open(file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
for page in reader.pages:
text += page.extract_text() or ""
except Exception as e:
print(f"Błąd PDF: {str(e)}")
return text
elif ext in ['.doc', '.docx']: elif ext in ['.doc', '.docx']:
return docx2txt.process(file_path) return docx2txt.process(file_path)
elif ext in ['.jpg', '.jpeg', '.png']: elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
return self.extract_image(file_path) return pytesseract.image_to_string(Image.open(file_path))
else: else:
with open(file_path, 'r', encoding='utf-8') as f: print(f"Nieobsługiwany format pliku: {ext}")
return f.read() return ""
except Exception as e: except Exception as e:
print(f"Błąd przetwarzania {file_path}: {str(e)}") print(f"Błąd ekstrakcji tekstu: {str(e)}")
return "" return ""
def extract_pdf(self, path): def prepare_dataset(directory, catalog_path, source_mapper):
text = "" file_catalog = load_file_catalog(catalog_path)
with open(path, 'rb') as f: data = []
reader = PyPDF2.PdfReader(f)
for page in reader.pages:
text += page.extract_text() + "\n"
return re.sub(r'\s+', ' ', text)
def extract_image(self, path): print(f"\n{'='*50}\nDIAGNOSTYKA DANYCH\n{'='*50}")
return pytesseract.image_to_string(
Image.open(path),
config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
)
def identify_doc_type(self, file_path): for root, _, files in os.walk(directory):
base = os.path.splitext(os.path.basename(file_path))[0].lower() for file in files:
return self.catalog.get(base, "Custom") file_path = os.path.join(root, file)
print(f"\nPrzetwarzanie pliku: {file_path}")
def split_content(self, text, doc_type): try:
if doc_type == "Custom": text = extract_text_from_file(file_path)
return self.split_custom(text) if not text.strip():
return self.split_legal(text, doc_type) print("Pominięto - brak tekstu")
def split_legal(self, text, doc_type):
pattern = r'(?i)(Art[\.\s]*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)'
parts = re.split(pattern, text)
results = []
current_header = ""
for part in parts:
if not part:
continue continue
if re.match(pattern, part):
if current_header:
results.append(current_header)
current_header = f"[{doc_type}] {part.strip()}"
else:
if current_header:
results.append(f"{current_header}: {part.strip()}")
current_header = ""
else:
results.append(part.strip())
return [text for text in results if len(text) > 50] print(f"Długość tekstu: {len(text)} znaków")
def split_custom(self, text): doc_type = identify_legal_document(file, file_catalog)
print(f"Rozpoznany typ dokumentu: {doc_type}")
if doc_type != "Opracowanie własne":
articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
articles = [a.strip() for a in articles if a.strip()]
print(f"Znaleziono {len(articles)} fragmentów")
for i in range(0, len(articles)-1, 2):
article_number = articles[i]
article_content = articles[i+1]
if len(article_content) < 50:
continue
source = f"{doc_type}, {article_number}"
source_mapper.add_source(source)
data.append({
"text": f"{article_number} {article_content}",
"source_idx": source_mapper.get_idx(source)
})
else:
clean_text = re.sub(r'\s+', ' ', text).strip() clean_text = re.sub(r'\s+', ' ', text).strip()
chunk_size = 384 chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
overlap = 64 chunks = [c for c in chunks if c.strip()]
chunks = [] for chunk in chunks:
start = 0 data.append({
while start < len(clean_text): "text": chunk,
end = start + chunk_size "source_idx": -1
chunks.append(clean_text[start:end]) })
start = end - overlap print(f"Dodano {len(chunks)} chunków")
return [f"[Custom] {chunk}" for chunk in chunks if chunk.strip()] except Exception as e:
print(f"Błąd podczas przetwarzania pliku: {str(e)}")
continue
class CustomModel(torch.nn.Module): print(f"\nPodsumowanie przygotowania danych:")
def __init__(self, model_name): print(f"Łączna liczba przykładów: {len(data)}")
if data:
print("Przykładowy wpis:")
print(json.dumps(data[0], indent=2, ensure_ascii=False))
else:
print("BRAK DANYCH - sprawdź diagnostykę powyżej")
return data
class CustomModel(nn.Module):
def __init__(self, model_name, config):
super().__init__() super().__init__()
self.base_model = AutoModelForCausalLM.from_pretrained(model_name) self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
self.source_emb = torch.nn.Embedding(1000, self.base_model.config.hidden_size) self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
# Zamrożenie parametrów bazowych
for param in self.base_model.parameters(): for param in self.base_model.parameters():
param.requires_grad = False param.requires_grad = False
for param in self.base_model.get_output_embeddings().parameters():
# Odmrożenie ostatnich warstw
for layer in self.base_model.transformer.h[-2:]:
for param in layer.parameters():
param.requires_grad = True param.requires_grad = True
self.base_model.get_output_embeddings().requires_grad_(True) def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
if source_idx is not None:
def forward(self, input_ids, attention_mask, labels, source_idx): valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
inputs_embeds = self.base_model.get_input_embeddings()(input_ids) source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
source_emb = self.source_emb(source_idx.clamp(0, 999)).unsqueeze(1) inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
inputs_embeds += source_emb
return self.base_model( return self.base_model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
attention_mask=attention_mask, attention_mask=attention_mask,
labels=labels labels=labels,
**kwargs
)
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
**kwargs
) )
def generate(self, *args, **kwargs):
return self.base_model.generate(*args, **kwargs)
class CustomDataCollator(DataCollatorForLanguageModeling):
def torch_call(self, examples):
# Przetwórz podstawowe pola
input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
# Dodaj source_idx jeśli istnieje
if "source_idx" in examples[0]:
source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
batch["source_idx"] = source_idx
return batch
def main(): def main():
# Inicjalizacja komponentów
source_mapper = SourceMapper() source_mapper = SourceMapper()
processor = LegalProcessor("file_catalog.json") model_name = "crumb/nano-mistral"
tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral") tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
# Przetwarzanie danych # Przygotowanie danych
data = [] catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path, source_mapper)
def process_and_augment(file_path): if not data:
try: print("\nBrak danych do treningu!")
items = processor.process_file(file_path) return
for text in items:
source = text.split("]")[0][1:]
source_mapper.add_source(source)
# Oryginalny tekst
data.append({
"text": text,
"source_idx": source_mapper.get_idx(source)
})
# Augmentacja
augmented = processor.augmenter.augment(text)
if augmented != text:
data.append({
"text": augmented,
"source_idx": source_mapper.get_idx(source)
})
except Exception as e:
print(f"Błąd przetwarzania {file_path}: {str(e)}")
# Przetwarzanie wielowątkowe
with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
futures = []
for root, _, files in os.walk("files"):
for file in files:
file_path = os.path.join(root, file)
futures.append(executor.submit(process_and_augment, file_path))
for future in futures:
future.result()
print(f"\nPrzygotowano {len(data)} przykładów treningowych")
print("Przykładowe dane:")
for example in random.sample(data, 3):
print(f"\nŹródło: {source_mapper.get_source(example['source_idx'])}")
print(f"Tekst: {example['text'][:150]}...")
# Przygotowanie datasetu
dataset = Dataset.from_list(data) dataset = Dataset.from_list(data)
def tokenize_fn(examples): def tokenize_function(examples):
tokenized = tokenizer( tokenized = tokenizer(
examples["text"], examples["text"],
max_length=512,
padding="max_length",
truncation=True, truncation=True,
padding="max_length",
max_length=512,
return_tensors="pt" return_tensors="pt"
) )
return { return {
"input_ids": tokenized["input_ids"].squeeze(), "input_ids": tokenized["input_ids"].squeeze(),
"attention_mask": tokenized["attention_mask"].squeeze(), "attention_mask": tokenized["attention_mask"].squeeze(),
"labels": tokenized["input_ids"].squeeze(), "labels": tokenized["input_ids"].squeeze().clone(),
"source_idx": examples["source_idx"] "source_idx": examples["source_idx"] # Dodano bez konwersji do tensora
} }
tokenized_ds = dataset.map( tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
tokenize_fn,
batched=True,
batch_size=32,
num_proc=4
)
# Inicjalizacja modelu model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
model = CustomModel("crumb/nano-mistral") model.source_mapper = source_mapper
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) model.to(device)
# Konfiguracja treningu
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir="./wyniki", output_dir="./results",
num_train_epochs=5, num_train_epochs=3,
per_device_train_batch_size=2, per_device_train_batch_size=2,
gradient_accumulation_steps=8, gradient_accumulation_steps=4,
learning_rate=2e-5, learning_rate=2e-5,
fp16=torch.cuda.is_available(), fp16=torch.cuda.is_available(),
logging_steps=20, logging_steps=10,
save_strategy="epoch", save_strategy="steps",
report_to="none" save_steps=1000,
report_to="none",
remove_unused_columns=False
) )
trainer = Trainer( trainer = Trainer(
model=model, model=model,
args=training_args, args=training_args,
train_dataset=tokenized_ds, train_dataset=tokenized_dataset,
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
) )
# Trening print("\nRozpoczęcie treningu...")
print("\nRozpoczynanie treningu...")
trainer.train() trainer.train()
# Zapis modelu
model.save_pretrained("./trained_legal_model")
tokenizer.save_pretrained("./trained_legal_model")
print("Trening zakończony pomyślnie!")
if __name__ == "__main__": if __name__ == "__main__":
main() main()