diff --git a/hft.py b/hft.py index fbcbd8c..05ab308 100644 --- a/hft.py +++ b/hft.py @@ -170,7 +170,6 @@ model.gradient_checkpointing_enable() training_args = TrainingArguments( output_dir="./results", num_train_epochs=3, - per_device_train_batch_size=2, gradient_accumulation_steps=4, learning_rate=2e-5, fp16=True,