diff --git a/hft.py b/hft.py index 6190912..9348bae 100644 --- a/hft.py +++ b/hft.py @@ -139,7 +139,7 @@ class CustomModel(AutoModelForCausalLM): return outputs class CustomTrainer(Trainer): - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): labels = inputs.pop("labels") source_idx = inputs.pop("source_idx", None) outputs = model(**inputs, labels=labels, source_idx=source_idx)