diff --git a/hft.py b/hft.py index be2a268..5679524 100644 --- a/hft.py +++ b/hft.py @@ -245,7 +245,7 @@ def main(): remove_unused_columns=False ) - trainer = CustomTrainer( + trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset,