This commit is contained in:
l.gabrysiak 2025-02-25 13:51:10 +01:00
parent 7796cc9ef0
commit 9cbcaa36ee
1 changed files with 3 additions and 1 deletions

4
hft.py
View File

@ -164,6 +164,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
# Inicjalizacja modelu
config = AutoModelForCausalLM.from_pretrained(model_name).config
model = CustomModel.from_pretrained(model_name, config=config)
model.gradient_checkpointing_enable()
# Konfiguracja treningu
training_args = TrainingArguments(
@ -185,7 +186,8 @@ trainer = CustomTrainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=custom_collate_fn # Użyj niestandardowego collate_fn
data_collator=custom_collate_fn, # Użyj niestandardowego collate_fn
batch_size=8 # zmniejszenie rozmiaru batcha
)
trainer.train()