diff --git a/hft.py b/hft.py
index d81cc03..16b75a0 100644
--- a/hft.py
+++ b/hft.py
@@ -110,9 +110,8 @@ def custom_collate_fn(batch):
     attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
     labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
-    # Add a default source_idx if one is missing
     source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
-
+    print("source_idx shape:", source_idx.shape)  # Debugging
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
 
 
 class CustomModel(AutoModelForCausalLM):
@@ -133,18 +132,19 @@ class CustomModel(AutoModelForCausalLM):
         )
 
         if source_idx is not None:
-            # Add the source embedding to the hidden states
-            source_embeds = self.source_embedding(source_idx).unsqueeze(1)
-            outputs.logits += source_embeds
-
+            # Add source_idx handling logic here
+            pass
+
         return outputs
 
+
 class CustomTrainer(Trainer):
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         labels = inputs.pop("labels")
-        source_idx = inputs.pop("source_idx")
+        source_idx = inputs.pop("source_idx", None)
         outputs = model(**inputs, labels=labels, source_idx=source_idx)
-        return (outputs.loss, outputs) if return_outputs else outputs.loss
+        loss = outputs.loss
+        return (loss, outputs) if return_outputs else loss
 
 # Initialize components
 source_mapper = SourceMapper()
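
Note (not part of the patch): the hunk in CustomModel.forward now leaves source_idx handling as a placeholder, and compute_loss gains an explicit num_items_in_batch parameter, presumably to match newer transformers releases, which pass that keyword to Trainer.compute_loss. Below is a minimal, self-contained sketch of the collate behavior from the first hunk, useful as a sanity check; the input_ids stacking line and the toy batch values are assumptions for illustration (the debug print is omitted), not code taken from hft.py.

import torch

def custom_collate_fn(batch):
    # Mirrors the collate logic in hft.py: stack the token fields and
    # default source_idx to -1 for examples that do not carry one.
    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
    labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
    return {"input_ids": input_ids, "attention_mask": attention_mask,
            "labels": labels, "source_idx": source_idx}

# Toy batch: the second example carries no source_idx, so it should default to -1.
batch = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3], "source_idx": 4},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 0], "labels": [4, 5, -100]},
]
out = custom_collate_fn(batch)
print(out["source_idx"])       # tensor([ 4, -1])
print(out["input_ids"].shape)  # torch.Size([2, 3])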