This commit is contained in:
l.gabrysiak 2025-02-26 00:26:35 +01:00
parent 992b55745e
commit 2dd8198c3a
1 changed file with 13 additions and 5 deletions

18
gpt.py
View File

@@ -1,6 +1,6 @@
 import os
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
 from datasets import Dataset
 # Konfiguracja
@@ -24,21 +24,28 @@ def main():
     data = prepare_simple_dataset()
     dataset = Dataset.from_dict({"text": [d["text"] for d in data]})
-    # Tokenizacja
+    # Tokenizacja z prawidłowymi etykietami
     def tokenize_function(examples):
-        return tokenizer(
+        tokenized = tokenizer(
             examples["text"],
             truncation=True,
             padding="max_length",
             max_length=128,
             return_tensors="pt"
         )
+        tokenized["labels"] = tokenized["input_ids"].clone()
+        return tokenized
     tokenized_dataset = dataset.map(tokenize_function, batched=True)
-    # Model
+    # Model i data collator
     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
     model.resize_token_embeddings(len(tokenizer))
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False
+    )
     # Konfiguracja treningu
     training_args = TrainingArguments(
@@ -47,7 +54,7 @@ def main():
         per_device_train_batch_size=2,
         remove_unused_columns=True,
         logging_steps=1,
-        report_to="none"  # Wyłączenie raportowania
+        report_to="none"
     )
     # Trainer
@@ -55,6 +62,7 @@ def main():
         model=model,
         args=training_args,
         train_dataset=tokenized_dataset,
+        data_collator=data_collator
     )
     print("Rozpoczęcie treningu...")