From c55fbe8632eb1e6ab11ab051f8fb952b0afc1543 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 22:17:13 +0100
Subject: [PATCH] Refactor hft.py: class-based processing, threaded file
 ingestion, shuffle augmentation

---
 hft.py | 379 ++++++++++++++++++++++++++++-----------------------------
 1 file changed, 188 insertions(+), 191 deletions(-)

diff --git a/hft.py b/hft.py
index 7f4292f..b8c90c9 100644
--- a/hft.py
+++ b/hft.py
@@ -1,21 +1,30 @@
 import os
 import torch
-import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
-from datasets import Dataset
+import random
 import re
 import json
 import PyPDF2
 import docx2txt
 import pytesseract
 from PIL import Image
 from collections import defaultdict
+from multiprocessing import cpu_count
+from concurrent.futures import ThreadPoolExecutor
+from threading import Lock
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling
+)
+from datasets import Dataset
 from huggingface_hub import login
 
-# Konfiguracja
-os.environ['TORCH_USE_CUDA_DSA'] = '1'
+# Configuration
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="hf_[REDACTED]")
+login(token=os.environ["HF_TOKEN"])  # read the token from the environment; never commit secrets
 
 class SourceMapper:
     def __init__(self):
@@ -33,227 +42,215 @@ class SourceMapper:
     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
 
-def load_file_catalog(catalog_path):
-    try:
-        with open(catalog_path, 'r', encoding='utf-8') as file:
-            return json.load(file)
-    except Exception as e:
-        print(f"Błąd wczytywania katalogu plików: {str(e)}")
-        return {}
-
-def identify_legal_document(filename, file_catalog):
-    base_name = os.path.splitext(filename)[0].lower()
-    return file_catalog.get(base_name, "Opracowanie własne")
-
-def extract_text_from_file(file_path):
-    try:
-        _, ext = os.path.splitext(file_path)
-        ext = ext.lower()
+class LegalDataProcessor:
+    def __init__(self, catalog_path):
+        self.catalog = self.load_catalog(catalog_path)
 
-        if ext in ['.txt', '.md']:
-            with open(file_path, 'r', encoding='utf-8') as file:
-                return file.read()
-        elif ext == '.pdf':
-            text = ""
-            try:
-                with open(file_path, 'rb') as file:
-                    reader = PyPDF2.PdfReader(file)
-                    for page in reader.pages:
-                        text += page.extract_text() or ""
-            except Exception as e:
-                print(f"Błąd PDF: {str(e)}")
-            return text
-        elif ext in ['.doc', '.docx']:
-            return docx2txt.process(file_path)
-        elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-            return pytesseract.image_to_string(Image.open(file_path))
-        else:
-            print(f"Nieobsługiwany format pliku: {ext}")
+    def load_catalog(self, path):
+        try:
+            with open(path, 'r', encoding='utf-8') as f:
+                return json.load(f)
+        except Exception as e:
+            # A bare except would also swallow KeyboardInterrupt/SystemExit
+            print(f"Failed to load file catalog: {e}")
+            return {}
+
+    def identify_document(self, filename):
+        base = os.path.splitext(filename)[0].lower()
+        return self.catalog.get(base, "Opracowanie własne")
+
+    def extract_text(self, file_path):
+        ext = os.path.splitext(file_path)[1].lower()
+        try:
+            if ext == '.pdf':
+                return self._extract_pdf(file_path)
+            elif ext in ['.doc', '.docx']:
+                return docx2txt.process(file_path)
+            elif ext in ['.jpg', '.jpeg', '.png']:
+                return self._extract_ocr(file_path)
+            else:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    return f.read()
+        except Exception as e:
+            print(f"Error processing {file_path}: {e}")
             return ""
-    except Exception as e:
-        print(f"Błąd ekstrakcji tekstu: {str(e)}")
-        return ""
-
-def prepare_dataset(directory, catalog_path, source_mapper):
-    file_catalog = load_file_catalog(catalog_path)
-    data = []
-    print(f"\n{'='*50}\nDIAGNOSTYKA DANYCH\n{'='*50}")
+    def _extract_pdf(self, path):
+        text = ""
+        with open(path, 'rb') as f:
+            reader = PyPDF2.PdfReader(f)
+            for page in reader.pages:
+                # extract_text() may return None for image-only pages
+                text += (page.extract_text() or "") + "\n"
+        return re.sub(r'\s+', ' ', text)
 
-    for root, _, files in os.walk(directory):
-        for file in files:
-            file_path = os.path.join(root, file)
-            print(f"\nPrzetwarzanie pliku: {file_path}")
-
-            try:
-                text = extract_text_from_file(file_path)
-                if not text.strip():
-                    print("Pominięto - brak tekstu")
-                    continue
-
-                print(f"Długość tekstu: {len(text)} znaków")
-
-                doc_type = identify_legal_document(file, file_catalog)
-                print(f"Rozpoznany typ dokumentu: {doc_type}")
-
-                if doc_type != "Opracowanie własne":
-                    articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
-                    articles = [a.strip() for a in articles if a.strip()]
-
-                    print(f"Znaleziono {len(articles)} fragmentów")
-
-                    for i in range(0, len(articles)-1, 2):
-                        article_number = articles[i]
-                        article_content = articles[i+1]
-
-                        if len(article_content) < 50:
-                            continue
-
-                        source = f"{doc_type}, {article_number}"
-                        source_mapper.add_source(source)
-                        data.append({
-                            "text": f"{article_number} {article_content}",
-                            "source_idx": source_mapper.get_idx(source)
-                        })
-                else:
-                    clean_text = re.sub(r'\s+', ' ', text).strip()
-                    chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
-                    chunks = [c for c in chunks if c.strip()]
-
-                    for chunk in chunks:
-                        data.append({
-                            "text": chunk,
-                            "source_idx": -1
-                        })
-                    print(f"Dodano {len(chunks)} chunków")
-
-            except Exception as e:
-                print(f"Błąd podczas przetwarzania pliku: {str(e)}")
-                continue
-
-    print(f"\nPodsumowanie przygotowania danych:")
-    print(f"Łączna liczba przykładów: {len(data)}")
-    if data:
-        print("Przykładowy wpis:")
-        print(json.dumps(data[0], indent=2, ensure_ascii=False))
-    else:
-        print("BRAK DANYCH - sprawdź diagnostykę powyżej")
-
-    return data
-
-class CustomModel(nn.Module):
-    def __init__(self, model_name, config):
-        super().__init__()
-        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-        self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
-
-        for param in self.base_model.parameters():
-            param.requires_grad = False
-        for param in self.base_model.get_output_embeddings().parameters():
-            param.requires_grad = True
-
-    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        if source_idx is not None:
-            valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
-            source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
-            inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
-            return self.base_model(
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                labels=labels,
-                **kwargs
-            )
-        return self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
+    def _extract_ocr(self, path):
+        return pytesseract.image_to_string(
+            Image.open(path),
+            config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
         )
 
-    def generate(self, *args, **kwargs):
-        return self.base_model.generate(*args, **kwargs)
+    def process_legal(self, text, doc_type):
+        # Split on article / paragraph / chapter headers of Polish legal texts
+        articles = re.split(
+            r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
+            text
+        )
+        processed = []
+        current_header = ""
+
+        for item in articles:
+            if item and re.match(r'(?i)(Art|§|Rozdział)', item):
+                if current_header:
+                    processed.append(current_header)
+                current_header = item.strip()
+            elif current_header:
+                processed.append(current_header + " " + item.strip())
+                current_header = ""
+            elif item:
+                processed.append(item.strip())
+
+        return [
+            (f"[{doc_type}] {p}", doc_type)
+            for p in processed if len(p) > 30
+        ]
+
+    def process_custom(self, text):
+        clean_text = re.sub(r'\s+', ' ', text).strip()
+        chunk_size = 384
+        overlap = 128
+
+        # Overlapping windows so context is not lost at chunk borders
+        chunks = [
+            clean_text[i:i+chunk_size]
+            for i in range(0, len(clean_text), chunk_size - overlap)
+        ]
+        return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]
 
-class CustomDataCollator(DataCollatorForLanguageModeling):
+class EnhancedDataCollator(DataCollatorForLanguageModeling):
     def torch_call(self, examples):
-        # Przetwórz podstawowe pola
-        input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
-        attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
-        labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
-
-        batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels
-        }
-
-        # Dodaj source_idx jeśli istnieje
+        # Pop source_idx before padding: tokenizer.pad() only understands
+        # standard tokenizer fields. Note that with pad_token == eos_token,
+        # the base collator (mlm=False) masks EOS positions to -100 as well.
+        source_idx = None
         if "source_idx" in examples[0]:
-            source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
-            batch["source_idx"] = source_idx
-
+            source_idx = [ex.pop("source_idx") for ex in examples]
+        batch = super().torch_call(examples)
+        if source_idx is not None:
+            # Only reaches the model if its forward() accepts source_idx;
+            # Trainer drops the column otherwise.
+            batch["source_idx"] = torch.tensor(source_idx, dtype=torch.long)
         return batch
 
 def main():
+    # Configuration
     source_mapper = SourceMapper()
-    model_name = "crumb/nano-mistral"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    processor = LegalDataProcessor("file_catalog.json")
+    tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
     tokenizer.pad_token = tokenizer.eos_token
-
-    # Przygotowanie danych
-    catalog_path = "file_catalog.json"
-    data = prepare_dataset("files", catalog_path, source_mapper)
-    if not data:
-        print("\nBrak danych do treningu!")
-        return
-
+
+    # Data preparation
+    data = []
+    lock = Lock()  # data and source_mapper are shared across worker threads
+
+    def process_file(file_path):
+        text = processor.extract_text(file_path)
+        if not text:
+            return
+
+        doc_type = processor.identify_document(os.path.basename(file_path))
+        if doc_type != "Opracowanie własne":
+            processed = processor.process_legal(text, doc_type)
+        else:
+            processed = processor.process_custom(text)
+
+        # Guard the check-then-act in add_source against thread races
+        with lock:
+            for chunk_text, source in processed:
+                source_mapper.add_source(source)
+                data.append({
+                    "text": chunk_text,
+                    "source_idx": source_mapper.get_idx(source)
+                })
+
+    # Multithreaded processing
+    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
+        futures = []
+        for root, _, files in os.walk("files"):
+            for file in files:
+                futures.append(executor.submit(
+                    process_file,
+                    os.path.join(root, file)
+                ))
+
+        for future in futures:
+            try:
+                future.result()
+            except Exception as e:
+                print(f"Error: {e}")
+
+    if not data:
+        print("No training data found!")
+        return
+
+    # Augmentation: two sentence-shuffled variants per example (crude but
+    # cheap; shuffling preserves vocabulary, not discourse order)
+    print(f"Before augmentation: {len(data)} examples")
+    augmented = []
+    for item in data:
+        for _ in range(2):
+            sentences = item['text'].split('. ')
+            random.shuffle(sentences)
+            augmented.append({
+                "text": '. '.join(sentences),
+                "source_idx": item["source_idx"]
+            })
+    data += augmented
+    print(f"After augmentation: {len(data)} examples")
+
+    # Dataset preparation
     dataset = Dataset.from_list(data)
-
-    def tokenize_function(examples):
-        tokenized = tokenizer(
-            examples["text"],
-            truncation=True,
-            padding="max_length",
-            max_length=512,
-            return_tensors="pt"
-        )
-        return {
-            "input_ids": tokenized["input_ids"].squeeze(),
-            "attention_mask": tokenized["attention_mask"].squeeze(),
-            "labels": tokenized["input_ids"].squeeze().clone(),
-            "source_idx": examples["source_idx"]  # Dodano bez konwersji do tensora
-        }
-
-    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
-
-    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
-    model.source_mapper = source_mapper
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-
+
+    def tokenize_fn(examples):
+        # Batched map expects plain lists; return_tensors="pt" plus
+        # .squeeze() would corrupt shapes for batches of size 1.
+        tokenized = tokenizer(
+            examples["text"],
+            max_length=512,
+            padding="max_length",
+            truncation=True
+        )
+        # Labels are built by the collator (mlm=False); just carry
+        # source_idx through.
+        tokenized["source_idx"] = examples["source_idx"]
+        return tokenized
+
+    tokenized_ds = dataset.map(
+        tokenize_fn,
+        batched=True,
+        batch_size=32,
+        num_proc=4,
+        remove_columns=["text"]
+    )
+
+    # Model
+    model = AutoModelForCausalLM.from_pretrained(
+        "crumb/nano-mistral",
+        trust_remote_code=True
+    )
+    model.resize_token_embeddings(len(tokenizer))  # no-op unless tokens were added
+
+    # Training
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=3,
+        num_train_epochs=5,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        learning_rate=2e-5,
+        gradient_accumulation_steps=8,
+        learning_rate=1e-4,
         fp16=torch.cuda.is_available(),
-        logging_steps=10,
-        save_strategy="steps",
-        save_steps=1000,
-        report_to="none",
-        remove_unused_columns=False
+        logging_steps=20,
+        save_strategy="epoch",
+        report_to="none"
     )
-
+
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_dataset,
-        data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
+        train_dataset=tokenized_ds,
+        data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
     )
-
-    print("\nRozpoczęcie treningu...")
+
+    print("Starting training...")
     trainer.train()
+    print("Training finished!")
+
+    # Save the final model and tokenizer together
+    model.save_pretrained("./trained_model")
+    tokenizer.save_pretrained("./trained_model")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
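
---

A quick smoke test of the saved artifacts is useful after the run. This is
a minimal sketch, assuming the training above completed and wrote
./trained_model; the "[Kodeks cywilny] Art. 1" prompt is a hypothetical
example of the "[doc_type] Art. N" format produced by process_legal(), not
a value taken from this patch:

    # smoke_test.py - hypothetical usage example, not part of the patch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("./trained_model")
    model = AutoModelForCausalLM.from_pretrained(
        "./trained_model",
        trust_remote_code=True
    )

    # Prompt in the same "[source] Art. N" shape as the training data
    inputs = tokenizer("[Kodeks cywilny] Art. 1", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

identify_document() falls back to "Opracowanie własne" for any file missing
from file_catalog.json, so the catalog is expected to map lowercase base
filenames to document titles, e.g. {"kodeks_cywilny": "Kodeks cywilny"}
(hypothetical entry).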