TEN KOD DZIAŁA

This commit is contained in:
l.gabrysiak 2025-02-25 22:54:44 +01:00
parent 999eded568
commit 9004cd8cc1
1 changed files with 194 additions and 218 deletions

412
hft.py
View File

@ -1,35 +1,21 @@
import nltk
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('punkt', quiet=True)
import os
import torch
import random
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import Dataset
import re
import json
import PyPDF2
import docx2txt
import pytesseract
import numpy as np
from PIL import Image
from collections import defaultdict
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import Dataset
from nlpaug.augmenter.word import SynonymAug
from huggingface_hub import login
# Konfiguracja
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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") # Zastąp swoim tokenem
login(token="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")
class SourceMapper:
def __init__(self):
@ -47,237 +33,227 @@ class SourceMapper:
def get_source(self, idx):
return self.idx_to_source.get(idx, "Unknown")
class LegalProcessor:
def __init__(self, catalog_path):
self.catalog = self.load_catalog(catalog_path)
self.augmenter = SynonymAug(aug_src='wordnet', aug_max=3, lang='pol')
def load_file_catalog(catalog_path):
try:
with open(catalog_path, 'r', encoding='utf-8') as file:
return json.load(file)
except Exception as e:
print(f"Błąd wczytywania katalogu plików: {str(e)}")
return {}
def identify_legal_document(filename, file_catalog):
base_name = os.path.splitext(filename)[0].lower()
return file_catalog.get(base_name, "Opracowanie własne")
def extract_text_from_file(file_path):
try:
_, ext = os.path.splitext(file_path)
ext = ext.lower()
def load_catalog(self, path):
try:
with open(path, 'r', encoding='utf-8') as f:
return json.load(f)
except:
return defaultdict(str)
def process_file(self, file_path):
text = self.extract_text(file_path)
if not text:
return []
doc_type = self.identify_doc_type(file_path)
return self.split_content(text, doc_type)
def extract_text(self, file_path):
ext = os.path.splitext(file_path)[1].lower()
try:
if ext == '.pdf':
return self.extract_pdf(file_path)
elif ext in ['.doc', '.docx']:
return docx2txt.process(file_path)
elif ext in ['.jpg', '.jpeg', '.png']:
return self.extract_image(file_path)
else:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"Błąd przetwarzania {file_path}: {str(e)}")
if ext in ['.txt', '.md']:
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
elif ext == '.pdf':
text = ""
try:
with open(file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
for page in reader.pages:
text += page.extract_text() or ""
except Exception as e:
print(f"Błąd PDF: {str(e)}")
return text
elif ext in ['.doc', '.docx']:
return docx2txt.process(file_path)
elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
return pytesseract.image_to_string(Image.open(file_path))
else:
print(f"Nieobsługiwany format pliku: {ext}")
return ""
def extract_pdf(self, path):
text = ""
with open(path, 'rb') as f:
reader = PyPDF2.PdfReader(f)
for page in reader.pages:
text += page.extract_text() + "\n"
return re.sub(r'\s+', ' ', text)
def extract_image(self, path):
return pytesseract.image_to_string(
Image.open(path),
config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
)
def identify_doc_type(self, file_path):
base = os.path.splitext(os.path.basename(file_path))[0].lower()
return self.catalog.get(base, "Custom")
def split_content(self, text, doc_type):
if doc_type == "Custom":
return self.split_custom(text)
return self.split_legal(text, doc_type)
def split_legal(self, text, doc_type):
pattern = r'(?i)(Art[\.\s]*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)'
parts = re.split(pattern, text)
results = []
current_header = ""
for part in parts:
if not part:
continue
if re.match(pattern, part):
if current_header:
results.append(current_header)
current_header = f"[{doc_type}] {part.strip()}"
else:
if current_header:
results.append(f"{current_header}: {part.strip()}")
current_header = ""
else:
results.append(part.strip())
return [text for text in results if len(text) > 50]
def split_custom(self, text):
clean_text = re.sub(r'\s+', ' ', text).strip()
chunk_size = 384
overlap = 64
chunks = []
start = 0
while start < len(clean_text):
end = start + chunk_size
chunks.append(clean_text[start:end])
start = end - overlap
return [f"[Custom] {chunk}" for chunk in chunks if chunk.strip()]
except Exception as e:
print(f"Błąd ekstrakcji tekstu: {str(e)}")
return ""
class CustomModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
self.source_emb = torch.nn.Embedding(1000, self.base_model.config.hidden_size)
# Zamrożenie parametrów bazowych
for param in self.base_model.parameters():
param.requires_grad = False
# Odmrożenie ostatnich warstw
for layer in self.base_model.transformer.h[-2:]:
for param in layer.parameters():
param.requires_grad = True
self.base_model.get_output_embeddings().requires_grad_(True)
def forward(self, input_ids, attention_mask, labels, source_idx):
inputs_embeds = self.base_model.get_input_embeddings()(input_ids)
source_emb = self.source_emb(source_idx.clamp(0, 999)).unsqueeze(1)
inputs_embeds += source_emb
return self.base_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels
)
def main():
# Inicjalizacja komponentów
source_mapper = SourceMapper()
processor = LegalProcessor("file_catalog.json")
tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
tokenizer.pad_token = tokenizer.eos_token
# Przetwarzanie danych
def prepare_dataset(directory, catalog_path, source_mapper):
file_catalog = load_file_catalog(catalog_path)
data = []
def process_and_augment(file_path):
try:
items = processor.process_file(file_path)
for text in items:
source = text.split("]")[0][1:]
source_mapper.add_source(source)
# Oryginalny tekst
data.append({
"text": text,
"source_idx": source_mapper.get_idx(source)
})
# Augmentacja
augmented = processor.augmenter.augment(text)
if augmented != text:
data.append({
"text": augmented,
"source_idx": source_mapper.get_idx(source)
})
except Exception as e:
print(f"Błąd przetwarzania {file_path}: {str(e)}")
print(f"\n{'='*50}\nDIAGNOSTYKA DANYCH\n{'='*50}")
# Przetwarzanie wielowątkowe
with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
futures = []
for root, _, files in os.walk("files"):
for file in files:
file_path = os.path.join(root, file)
futures.append(executor.submit(process_and_augment, file_path))
for root, _, files in os.walk(directory):
for file in files:
file_path = os.path.join(root, file)
print(f"\nPrzetwarzanie pliku: {file_path}")
try:
text = extract_text_from_file(file_path)
if not text.strip():
print("Pominięto - brak tekstu")
continue
print(f"Długość tekstu: {len(text)} znaków")
for future in futures:
future.result()
doc_type = identify_legal_document(file, file_catalog)
print(f"Rozpoznany typ dokumentu: {doc_type}")
if doc_type != "Opracowanie własne":
articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
articles = [a.strip() for a in articles if a.strip()]
print(f"Znaleziono {len(articles)} fragmentów")
for i in range(0, len(articles)-1, 2):
article_number = articles[i]
article_content = articles[i+1]
if len(article_content) < 50:
continue
source = f"{doc_type}, {article_number}"
source_mapper.add_source(source)
data.append({
"text": f"{article_number} {article_content}",
"source_idx": source_mapper.get_idx(source)
})
else:
clean_text = re.sub(r'\s+', ' ', text).strip()
chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
chunks = [c for c in chunks if c.strip()]
for chunk in chunks:
data.append({
"text": chunk,
"source_idx": -1
})
print(f"Dodano {len(chunks)} chunków")
except Exception as e:
print(f"Błąd podczas przetwarzania pliku: {str(e)}")
continue
print(f"\nPodsumowanie przygotowania danych:")
print(f"Łączna liczba przykładów: {len(data)}")
if data:
print("Przykładowy wpis:")
print(json.dumps(data[0], indent=2, ensure_ascii=False))
else:
print("BRAK DANYCH - sprawdź diagnostykę powyżej")
return data
print(f"\nPrzygotowano {len(data)} przykładów treningowych")
print("Przykładowe dane:")
for example in random.sample(data, 3):
print(f"\nŹródło: {source_mapper.get_source(example['source_idx'])}")
print(f"Tekst: {example['text'][:150]}...")
class CustomModel(nn.Module):
def __init__(self, model_name, config):
super().__init__()
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
for param in self.base_model.parameters():
param.requires_grad = False
for param in self.base_model.get_output_embeddings().parameters():
param.requires_grad = True
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
if source_idx is not None:
valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
return self.base_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
**kwargs
)
return self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
**kwargs
)
def generate(self, *args, **kwargs):
return self.base_model.generate(*args, **kwargs)
class CustomDataCollator(DataCollatorForLanguageModeling):
def torch_call(self, examples):
# Przetwórz podstawowe pola
input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
# Dodaj source_idx jeśli istnieje
if "source_idx" in examples[0]:
source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
batch["source_idx"] = source_idx
return batch
def main():
source_mapper = SourceMapper()
model_name = "crumb/nano-mistral"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Przygotowanie danych
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path, source_mapper)
if not data:
print("\nBrak danych do treningu!")
return
# Przygotowanie datasetu
dataset = Dataset.from_list(data)
def tokenize_fn(examples):
def tokenize_function(examples):
tokenized = tokenizer(
examples["text"],
max_length=512,
padding="max_length",
truncation=True,
padding="max_length",
max_length=512,
return_tensors="pt"
)
return {
"input_ids": tokenized["input_ids"].squeeze(),
"attention_mask": tokenized["attention_mask"].squeeze(),
"labels": tokenized["input_ids"].squeeze(),
"source_idx": examples["source_idx"]
"labels": tokenized["input_ids"].squeeze().clone(),
"source_idx": examples["source_idx"] # Dodano bez konwersji do tensora
}
tokenized_ds = dataset.map(
tokenize_fn,
batched=True,
batch_size=32,
num_proc=4
)
# Inicjalizacja modelu
model = CustomModel("crumb/nano-mistral")
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
model.source_mapper = source_mapper
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Konfiguracja treningu
training_args = TrainingArguments(
output_dir="./wyniki",
num_train_epochs=5,
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
gradient_accumulation_steps=4,
learning_rate=2e-5,
fp16=torch.cuda.is_available(),
logging_steps=20,
save_strategy="epoch",
report_to="none"
logging_steps=10,
save_strategy="steps",
save_steps=1000,
report_to="none",
remove_unused_columns=False
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_ds,
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
train_dataset=tokenized_dataset,
data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
)
# Trening
print("\nRozpoczynanie treningu...")
print("\nRozpoczęcie treningu...")
trainer.train()
# Zapis modelu
model.save_pretrained("./trained_legal_model")
tokenizer.save_pretrained("./trained_legal_model")
print("Trening zakończony pomyślnie!")
if __name__ == "__main__":
main()