mod
This commit is contained in:
parent
f3db6adbe0
commit
c0077c7c04
8
hft.py
8
hft.py
|
|
@ -132,14 +132,14 @@ class CustomModel(AutoModelForCausalLM):
|
|||
)
|
||||
|
||||
if source_idx is not None:
|
||||
# Tutaj dodaj logikę obsługi source_idx
|
||||
pass
|
||||
# Dodaj embedding źródła do logits
|
||||
source_embeds = self.source_embedding(source_idx).unsqueeze(1)
|
||||
outputs.logits += source_embeds
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop("labels")
|
||||
source_idx = inputs.pop("source_idx", None)
|
||||
outputs = model(**inputs, labels=labels, source_idx=source_idx)
|
||||
|
|
|
|||
Loading…
Reference in New Issue