From 4957a2898bdd8fb310f7ffd96552442da6fdcafa Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Wed, 26 Feb 2025 00:08:31 +0100
Subject: [PATCH] mod

---
 hft.py | 30 +++++++++++++-----------------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/hft.py b/hft.py
index 9c4ae30..7a43bad 100644
--- a/hft.py
+++ b/hft.py
@@ -103,21 +103,19 @@ def prepare_dataset(directory, catalog_path, source_mapper):
         current_section = ""
         current_chapter = ""
 
-        # Detect the document structure
         structure_matches = re.finditer(
            r'(DZIAŁ [A-ZĄĆĘŁŃÓŚŹŻ]+)\n+(.*?)\n(?=Art\.|Rozdział|DZIAŁ|$)'
            r'|(Rozdział [A-ZĄĆĘŁŃÓŚŹŻ]+)\n+(.*?)\n(?=Art\.|DZIAŁ|$)',
            text
        )
        for match in structure_matches:
-            if match.group(1):  # DZIAŁ
+            if match.group(1):
                 current_section = f"{match.group(1)} - {match.group(2).strip()}"
                 current_chapter = ""
-            else:  # Rozdział
+            else:
                 current_chapter = f"{match.group(3)} - {match.group(4).strip()}"
 
         if doc_type != "Opracowanie własne":
-            # Improved regex for articles
             articles = re.split(
                 r'(?i)(Art[\.\s]+\d+[a-z]*(?:[\s§\.-]\d+)*)\.?\s*',
                 text
@@ -133,7 +131,6 @@ def prepare_dataset(directory, catalog_path, source_mapper):
                 if len(article_content) < 50:
                     continue
 
-                # Citation formatting
                 citation_block = (
                     f"{CITATION_START}\n"
                     f"Dokument: {doc_type}\n"
@@ -177,15 +174,15 @@ def prepare_dataset(directory, catalog_path, source_mapper):
     return data
 
 class CustomModel(nn.Module):
-    def __init__(self, model_name, config):
+    def __init__(self, model_name, tokenizer):
         super().__init__()
+        config = AutoModelForCausalLM.from_pretrained(model_name).config
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
         self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
 
-        # Additional tokenizer initialization
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.tokenizer.add_special_tokens({'additional_special_tokens': [CITATION_START, CITATION_END]})
-        self.base_model.resize_token_embeddings(len(self.tokenizer))
+        # Add the special tokens and update the embeddings
+        tokenizer.add_special_tokens({'additional_special_tokens': [CITATION_START, CITATION_END]})
+        self.base_model.resize_token_embeddings(len(tokenizer))
 
         for param in self.base_model.parameters():
             param.requires_grad = False
@@ -218,7 +215,7 @@ class CustomDataCollator(DataCollatorForLanguageModeling):
         batch = super().torch_call(examples)
 
         if "source_idx" in examples[0]:
-            source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
+            source_idx = torch.tensor([ex["source_idx"] for ex in examples])
             batch["source_idx"] = source_idx
 
         return batch
@@ -226,12 +223,11 @@ class CustomDataCollator(DataCollatorForLanguageModeling):
 def main():
     source_mapper = SourceMapper()
     model_name = "crumb/nano-mistral"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-    # Add the special tokens to the tokenizer
-    tokenizer.add_special_tokens({'additional_special_tokens': [CITATION_START, CITATION_END]})
+    # Initialize the tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
-    
+
     # Prepare the data
     catalog_path = "catalog.json"
     data = prepare_dataset("docs", catalog_path, source_mapper)
@@ -259,8 +255,8 @@ def main():
 
     tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
 
-    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
-    model.source_mapper = source_mapper
+    # Initialize the model with the tokenizer
+    model = CustomModel(model_name, tokenizer)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
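
A minimal sketch of what the revised collator line produces (not part of the patch), assuming each example's "source_idx" is a single integer id; that assumption is implied by the dataset code but not shown in the hunks above:

    import torch

    # After this change, CustomDataCollator batches the per-example source ids
    # into a 1-D tensor directly instead of stacking per-example tensors.
    examples = [{"source_idx": 3}, {"source_idx": 7}]
    source_idx = torch.tensor([ex["source_idx"] for ex in examples])
    print(source_idx)  # tensor([3, 7]), shape (batch_size,)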