diff --git a/hft.py b/hft.py index 16b75a0..6190912 100644 --- a/hft.py +++ b/hft.py @@ -132,14 +132,14 @@ class CustomModel(AutoModelForCausalLM): ) if source_idx is not None: - # Tutaj dodaj logikę obsługi source_idx - pass + # Dodaj embedding źródła do logits + source_embeds = self.source_embedding(source_idx).unsqueeze(1) + outputs.logits += source_embeds return outputs - class CustomTrainer(Trainer): - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + def compute_loss(self, model, inputs, return_outputs=False): labels = inputs.pop("labels") source_idx = inputs.pop("source_idx", None) outputs = model(**inputs, labels=labels, source_idx=source_idx)