diff --git a/hft.py b/hft.py
index 1f2a4e7..9c4ae30 100644
--- a/hft.py
+++ b/hft.py
@@ -17,6 +17,10 @@
 os.environ['TORCH_USE_CUDA_DSA'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 login(token="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")
+# New special tokens
+CITATION_START = "▌▌CITATION_START"
+CITATION_END = "▌▌CITATION_END"
+
 class SourceMapper:
     def __init__(self):
         self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
@@ -96,8 +100,28 @@ def prepare_dataset(directory, catalog_path, source_mapper):
         doc_type = identify_legal_document(file, file_catalog)
         print(f"Rozpoznany typ dokumentu: {doc_type}")
 
+        current_section = ""
+        current_chapter = ""
+
+        # Detect the document structure
+        structure_matches = re.finditer(
+            r'(DZIAŁ [A-ZĄĆĘŁŃÓŚŹŻ]+)\n+(.*?)\n(?=Art\.|Rozdział|DZIAŁ|$)'
+            r'|(Rozdział [A-ZĄĆĘŁŃÓŚŹŻ]+)\n+(.*?)\n(?=Art\.|DZIAŁ|$)',
+            text
+        )
+        for match in structure_matches:
+            if match.group(1):  # DZIAŁ (division)
+                current_section = f"{match.group(1)} - {match.group(2).strip()}"
+                current_chapter = ""
+            else:  # Rozdział (chapter)
+                current_chapter = f"{match.group(3)} - {match.group(4).strip()}"
+
         if doc_type != "Opracowanie własne":
-            articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
+            # Improved regex for articles
+            articles = re.split(
+                r'(?i)(Art[\.\s]+\d+[a-z]*(?:[\s§\.-]\d+)*)\.?\s*',
+                text
+            )
             articles = [a.strip() for a in articles if a.strip()]
             print(f"Znaleziono {len(articles)} fragmentów")
 
@@ -109,10 +133,21 @@ def prepare_dataset(directory, catalog_path, source_mapper):
                 if len(article_content) < 50:
                     continue
 
+                # Format the citation block
+                citation_block = (
+                    f"{CITATION_START}\n"
+                    f"Dokument: {doc_type}\n"
+                    f"Artykuł: {article_number}\n"
+                    f"Sekcja: {current_section}\n"
+                    f"Rozdział: {current_chapter}\n"
+                    f"{CITATION_END}\n"
+                    f"{article_content}"
+                )
+
                 source = f"{doc_type}, {article_number}"
                 source_mapper.add_source(source)
                 data.append({
-                    "text": f"{article_number} {article_content}",
+                    "text": citation_block,
                     "source_idx": source_mapper.get_idx(source)
                 })
         else:
@@ -147,6 +182,11 @@ class CustomModel(nn.Module):
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
         self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
 
+        # Additional tokenizer initialization
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer.add_special_tokens({'additional_special_tokens': [CITATION_START, CITATION_END]})
+        self.base_model.resize_token_embeddings(len(self.tokenizer))
+
         for param in self.base_model.parameters():
             param.requires_grad = False
         for param in self.base_model.get_output_embeddings().parameters():
@@ -175,18 +215,8 @@ class CustomModel(nn.Module):
 
 
 class CustomDataCollator(DataCollatorForLanguageModeling):
     def torch_call(self, examples):
-        # Process the basic fields
-        input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples])
-        attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples])
-        labels = torch.stack([torch.tensor(ex["labels"]) for ex in examples])
+        batch = super().torch_call(examples)
-        batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "labels": labels
-        }
-
-        # Add source_idx if present
         if "source_idx" in examples[0]:
             source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples])
             batch["source_idx"] = source_idx
@@ -197,6 +227,9 @@ def main():
     source_mapper = SourceMapper()
     model_name = "crumb/nano-mistral"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # Add the special tokens to the tokenizer
+    tokenizer.add_special_tokens({'additional_special_tokens': [CITATION_START, CITATION_END]})
     tokenizer.pad_token = tokenizer.eos_token
 
     # Data preparation
@@ -207,10 +240,8 @@ def main():
         print("\nBrak danych do treningu!")
         return
 
-    #dataset = Dataset.from_list(data)
     dataset = Dataset.from_dict({k: [d[k] for d in data] for k in data[0]})
 
-
    def tokenize_function(examples):
         tokenized = tokenizer(
             examples["text"],
@@ -223,7 +254,7 @@ def main():
             "input_ids": tokenized["input_ids"].squeeze(),
             "attention_mask": tokenized["attention_mask"].squeeze(),
             "labels": tokenized["input_ids"].squeeze().clone(),
-            "source_idx": examples["source_idx"]  # Added without conversion to a tensor
+            "source_idx": examples["source_idx"]
         }
 
     tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)