diff --git a/hft.py b/hft.py index 88ab3b9..f47689c 100644 --- a/hft.py +++ b/hft.py @@ -99,7 +99,7 @@ class CustomModel(AutoModelForCausalLM): # Dostosowany Trainer class CustomTrainer(Trainer): - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): labels = inputs.pop("labels") source = inputs.pop("source") outputs = model(**inputs, labels=labels)