From ed44484e37ae2339b569ed8b2ab96727692c6567 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 16:04:07 +0100
Subject: [PATCH] Stub out source embedding logic; make source_idx optional in
 compute_loss

---
 hft.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/hft.py b/hft.py
index d81cc03..2ab3285 100644
--- a/hft.py
+++ b/hft.py
@@ -112,7 +112,7 @@ def custom_collate_fn(batch):
 
     # Add a default source_idx if one is missing
     source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
-
+    print("source_idx shape:", source_idx.shape)  # Debugging
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
 
 class CustomModel(AutoModelForCausalLM):
@@ -133,17 +133,17 @@ class CustomModel(AutoModelForCausalLM):
         )
 
         if source_idx is not None:
-            # Add the source embedding to the hidden states
-            source_embeds = self.source_embedding(source_idx).unsqueeze(1)
-            outputs.logits += source_embeds
-
+            # TODO: add source_idx handling logic here
+            pass
+
         return outputs
+
 class CustomTrainer(Trainer):
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs.pop("labels")
-        source_idx = inputs.pop("source_idx")
-        outputs = model(**inputs, labels=labels, source_idx=source_idx)
+        source_idx = inputs.pop("source_idx", None)
+        outputs = model(**inputs, labels=labels, source_idx=source_idx)
         return (outputs.loss, outputs) if return_outputs else outputs.loss
 
 
 # Component initialization
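
Note on the stubbed branch: below is a minimal sketch of one way the
source_idx handling could be filled in later, assuming the intent is to
condition the causal LM on a per-example source id. The wrapper class
SourceConditionedLM, its num_sources parameter, and the reserved slot for the
-1 sentinel are illustrative assumptions, not code from hft.py. Unlike the
removed lines, it adds the source embedding to the token embeddings rather
than to outputs.logits, where a hidden-size embedding typically cannot
broadcast against vocabulary-size logits; it also sidesteps subclassing
AutoModelForCausalLM, which is a factory class not designed to be subclassed
and instantiated directly.

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

class SourceConditionedLM(nn.Module):
    # Hypothetical wrapper, not part of hft.py: conditions a causal LM on a
    # per-example source id by adding a learned embedding to the token
    # embeddings before the forward pass.
    def __init__(self, base_model_name, num_sources):
        super().__init__()
        self.lm = AutoModelForCausalLM.from_pretrained(base_model_name)
        hidden = self.lm.get_input_embeddings().embedding_dim
        # Slot 0 is reserved for the -1 "unknown source" sentinel that
        # custom_collate_fn emits; real sources occupy slots 1..num_sources.
        self.source_embedding = nn.Embedding(num_sources + 1, hidden)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                source_idx=None, **kwargs):
        inputs_embeds = self.lm.get_input_embeddings()(input_ids)
        if source_idx is not None:
            # Shift so -1 -> 0, 0 -> 1, ...; clamp guards stray negatives.
            idx = (source_idx + 1).clamp(min=0)
            # (batch, hidden) -> (batch, 1, hidden), broadcast over sequence.
            inputs_embeds = inputs_embeds + self.source_embedding(idx).unsqueeze(1)
        return self.lm(inputs_embeds=inputs_embeds,
                       attention_mask=attention_mask,
                       labels=labels,
                       **kwargs)

With a model like this, the CustomTrainer.compute_loss above works unchanged,
since it already forwards source_idx. One version caveat: transformers 4.46+
passes an extra num_items_in_batch keyword into compute_loss, so an override
without **kwargs (as in this patch) can raise a TypeError on those versions.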