l.gabrysiak 2025-02-26 00:22:03 +01:00
parent b4957ee652
commit e588d3af66
1 changed file with 9 additions and 41 deletions

gpt.py

@@ -2,66 +2,37 @@ import os
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
-from collections import defaultdict

 # Configuration
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-MODEL_NAME = "gpt2"  # Temporarily using a smaller model for testing
+MODEL_NAME = "gpt2"
 SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"]

-class SourceMapper:
-    def __init__(self):
-        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
-        self.idx_to_source = {}
-
-    def add_source(self, source):
-        if source not in self.source_to_idx:
-            idx = self.source_to_idx[source]
-            self.idx_to_source[idx] = source

 def prepare_simple_dataset():
-    # Sample data - replace with real data
     return [
-        {
-            "text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu...",
-            "source_idx": 0
-        },
-        {
-            "text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst...",
-            "source_idx": 1
-        }
+        {"text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu..."},
+        {"text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst..."}
     ]

 def main():
-    # Initialization
+    # Initialize the tokenizer
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
     tokenizer.pad_token = tokenizer.eos_token

     # Prepare the data
-    source_mapper = SourceMapper()
     data = prepare_simple_dataset()
+    dataset = Dataset.from_dict({"text": [d["text"] for d in data]})

-    # Build the dataset
-    dataset = Dataset.from_dict({
-        "text": [d["text"] for d in data],
-        "source_idx": [d["source_idx"] for d in data]
-    })

     # Tokenization
     def tokenize_function(examples):
-        tokenized = tokenizer(
+        return tokenizer(
             examples["text"],
             truncation=True,
             padding="max_length",
             max_length=128,
             return_tensors="pt"
         )
-        return {
-            "input_ids": tokenized["input_ids"].squeeze(),
-            "attention_mask": tokenized["attention_mask"].squeeze(),
-            "labels": tokenized["input_ids"].squeeze().clone(),
-        }

     tokenized_dataset = dataset.map(tokenize_function, batched=True)
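A quick standalone check (not part of this commit) of what the tokenization above produces: markers registered via additional_special_tokens come back as single token ids rather than being split by BPE, and padding="max_length" yields fixed 128-token rows. Token values below are illustrative only.

from transformers import AutoTokenizer

# Same tokenizer setup as in the diff above.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["[CITATION_START]", "[CITATION_END]"]})
tokenizer.pad_token = tokenizer.eos_token

enc = tokenizer(
    "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu...",
    truncation=True,
    padding="max_length",
    max_length=128,
)
print(len(enc["input_ids"]))   # 128: every row is padded/truncated to max_length
print(tokenizer.convert_ids_to_tokens(enc["input_ids"])[:2])
# ['[CITATION_START]', ...]   # the added marker is encoded as one token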
@@ -72,12 +43,10 @@ def main():
     # Training configuration
     training_args = TrainingArguments(
         output_dir="./results",
         num_train_epochs=1,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=1,
-        learning_rate=2e-5,
-        logging_steps=1,
-        remove_unused_columns=False
+        remove_unused_columns=True,  # Key change
+        logging_steps=1
     )

     # Trainer
@@ -87,7 +56,6 @@ def main():
         train_dataset=tokenized_dataset,
     )

-    # Start training
     print("Rozpoczęcie treningu...")
     trainer.train()
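One thing the simplified script leaves open: the new tokenize_function no longer copies input_ids into a labels column, and with remove_unused_columns=True the Trainer drops any dataset column the model's forward does not accept (including the raw "text" column), so the model only receives input_ids and attention_mask and cannot return a loss by itself. A minimal sketch of one way to close that gap, assuming the standard transformers data collator; this is not part of the commit, only an illustration.

# Sketch only -- the collator is an assumption, not something this commit adds.
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["[CITATION_START]", "[CITATION_END]"]})
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))  # account for the two added marker tokens

dataset = Dataset.from_dict({"text": [
    "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu...",
    "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst...",
]})
tokenized = dataset.map(
    lambda ex: tokenizer(ex["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=["text"],
)

# mlm=False: the collator pads each batch and copies input_ids into labels,
# masking pad positions with -100, so the Trainer gets a causal-LM loss even
# though the tokenized dataset has no labels column.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=2,
        logging_steps=1,
    ),
    train_dataset=tokenized,
    data_collator=collator,
)
trainer.train()

An alternative closer to the removed code would be to build the labels column inside tokenize_function again (copying input_ids per example) and keep the default collator.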