diff --git a/gpt.py b/gpt.py
index 9b43685..4d3ad46 100644
--- a/gpt.py
+++ b/gpt.py
@@ -2,66 +2,37 @@
 import os
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
-from collections import defaultdict
 
 # Configuration
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-MODEL_NAME = "gpt2"  # Temporarily using a smaller model for testing
+MODEL_NAME = "gpt2"
 SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"]
 
-class SourceMapper:
-    def __init__(self):
-        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
-        self.idx_to_source = {}
-
-    def add_source(self, source):
-        if source not in self.source_to_idx:
-            idx = self.source_to_idx[source]
-            self.idx_to_source[idx] = source
-
 def prepare_simple_dataset():
-    # Sample data - replace with real data
     return [
-        {
-            "text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu...",
-            "source_idx": 0
-        },
-        {
-            "text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst...",
-            "source_idx": 1
-        }
+        {"text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu..."},
+        {"text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst..."}
     ]
 
 def main():
-    # Initialization
+    # Tokenizer initialization
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
     tokenizer.pad_token = tokenizer.eos_token
 
     # Data preparation
-    source_mapper = SourceMapper()
     data = prepare_simple_dataset()
-
-    # Dataset creation
-    dataset = Dataset.from_dict({
-        "text": [d["text"] for d in data],
-        "source_idx": [d["source_idx"] for d in data]
-    })
+    dataset = Dataset.from_dict({"text": [d["text"] for d in data]})
 
     # Tokenization
     def tokenize_function(examples):
-        tokenized = tokenizer(
+        return tokenizer(
             examples["text"],
             truncation=True,
             padding="max_length",
             max_length=128,
             return_tensors="pt"
         )
-        return {
-            "input_ids": tokenized["input_ids"].squeeze(),
-            "attention_mask": tokenized["attention_mask"].squeeze(),
-            "labels": tokenized["input_ids"].squeeze().clone(),
-        }
 
     tokenized_dataset = dataset.map(tokenize_function, batched=True)
 
@@ -72,12 +43,10 @@ def main():
     # Training configuration
     training_args = TrainingArguments(
         output_dir="./results",
         num_train_epochs=1,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=1,
-        learning_rate=2e-5,
-        logging_steps=1,
-        remove_unused_columns=False
+        remove_unused_columns=True,  # Key change
+        logging_steps=1
     )
 
     # Trainer
@@ -87,7 +56,6 @@ def main():
         train_dataset=tokenized_dataset,
     )
 
-    # Start training
     print("Rozpoczęcie treningu...")
     trainer.train()
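
Note: the rewritten tokenize_function no longer returns a "labels" column, so the causal-LM loss has to obtain labels some other way; the Trainer(...) construction sits outside the hunks shown, so it is not visible here whether a data collator is already passed. A minimal sketch, assuming none is: DataCollatorForLanguageModeling with mlm=False builds labels from input_ids at collation time.

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # With mlm=False the collator copies input_ids into labels and masks padded
    # positions with -100, replacing the explicit "labels" the old tokenize_function built.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Hypothetical wiring (the actual Trainer call is not shown in this diff):
    # trainer = Trainer(model=model, args=training_args,
    #                   train_dataset=tokenized_dataset, data_collator=data_collator)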