mod
This commit is contained in:
parent
c0077c7c04
commit
7c77d1c5b7
2
hft.py
2
hft.py
|
|
@ -139,7 +139,7 @@ class CustomModel(AutoModelForCausalLM):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
class CustomTrainer(Trainer):
|
class CustomTrainer(Trainer):
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||||
labels = inputs.pop("labels")
|
labels = inputs.pop("labels")
|
||||||
source_idx = inputs.pop("source_idx", None)
|
source_idx = inputs.pop("source_idx", None)
|
||||||
outputs = model(**inputs, labels=labels, source_idx=source_idx)
|
outputs = model(**inputs, labels=labels, source_idx=source_idx)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue