From 9cbcaa36ee799739b62c277e1e2af3c6157f91af Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 13:51:10 +0100
Subject: [PATCH] Reduce training memory: enable gradient checkpointing, shrink batch size

---
 hft.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/hft.py b/hft.py
index 195993f..eccec6c 100644
--- a/hft.py
+++ b/hft.py
@@ -164,6 +164,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
 # Inicjalizacja modelu
 config = AutoModelForCausalLM.from_pretrained(model_name).config
 model = CustomModel.from_pretrained(model_name, config=config)
+model.gradient_checkpointing_enable()
 
 # Konfiguracja treningu
 training_args = TrainingArguments(
@@ -185,7 +186,8 @@ trainer = CustomTrainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_dataset,
-    data_collator=custom_collate_fn # Użyj niestandardowego collate_fn
+    data_collator=custom_collate_fn, # Użyj niestandardowego collate_fn
+    batch_size=8 # zmniejszenie rozmiaru batcha -- NOTE(review): stock transformers.Trainer has no batch_size kwarg; unless CustomTrainer.__init__ accepts it, set per_device_train_batch_size=8 in TrainingArguments instead -- verify
 )
 
 trainer.train()