import os
from collections import defaultdict

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
MODEL_NAME = "gpt2"  # temporarily use a smaller model for testing
SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"]


class SourceMapper:
    """Assigns a stable integer index to every distinct citation source."""

    def __init__(self):
        # Looking up an unknown source assigns it the next free index.
        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
        self.idx_to_source = {}

    def add_source(self, source):
        if source not in self.source_to_idx:
            idx = self.source_to_idx[source]
            self.idx_to_source[idx] = source


def prepare_simple_dataset():
    # Sample data - replace with real data
    return [
        {
            "text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Article text...",
            "source_idx": 0,
        },
        {
            "text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Another text...",
            "source_idx": 1,
        },
    ]


def main():
    # Initialization
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    # Data preparation (source_mapper is not wired in yet; see the sketch at the end of the file)
    source_mapper = SourceMapper()
    data = prepare_simple_dataset()

    # Build the dataset
    dataset = Dataset.from_dict({
        "text": [d["text"] for d in data],
        "source_idx": [d["source_idx"] for d in data],
    })

    # Tokenization
    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        # Causal LM: labels are the input ids, with padding positions masked out (-100)
        tokenized["labels"] = [
            [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
            for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized

    # Drop the raw columns so only model inputs (input_ids, attention_mask, labels)
    # reach the Trainer; otherwise "text"/"source_idx" would be forwarded to the model.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )

    # Model
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model.resize_token_embeddings(len(tokenizer))  # account for the added special tokens

    # Training configuration
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=2e-5,
        logging_steps=1,
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    # Start training
    print("Starting training...")
    trainer.train()


if __name__ == "__main__":
    main()
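

# --- Illustrative sketch (not called from main) ------------------------------
# SourceMapper is instantiated above but never consulted when the toy records
# are built. The helper below is a rough sketch of how it could turn raw
# {"text", "source"} pairs into the {"text", "source_idx"} records that
# prepare_simple_dataset() hard-codes; the "source" field name and this helper
# are assumptions made for illustration, not part of the original pipeline.
def build_records_with_mapper(raw_records, mapper):
    records = []
    for item in raw_records:
        mapper.add_source(item["source"])  # registers the source and assigns the next index
        records.append({
            "text": item["text"],
            "source_idx": mapper.source_to_idx[item["source"]],
        })
    return records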