This commit is contained in:
l.gabrysiak 2025-02-25 16:59:14 +01:00
parent f3db6adbe0
commit c0077c7c04
1 changed files with 4 additions and 4 deletions

8
hft.py
View File

@ -132,14 +132,14 @@ class CustomModel(AutoModelForCausalLM):
) )
if source_idx is not None: if source_idx is not None:
# Tutaj dodaj logikę obsługi source_idx # Dodaj embedding źródła do logits
pass source_embeds = self.source_embedding(source_idx).unsqueeze(1)
outputs.logits += source_embeds
return outputs return outputs
class CustomTrainer(Trainer): class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels") labels = inputs.pop("labels")
source_idx = inputs.pop("source_idx", None) source_idx = inputs.pop("source_idx", None)
outputs = model(**inputs, labels=labels, source_idx=source_idx) outputs = model(**inputs, labels=labels, source_idx=source_idx)