From 18dc6ab28a6a4c11f4cd2ae5bfd37ef6a5c451e8 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 13:42:51 +0100
Subject: [PATCH] mod

---
 hft.py | 22 +++++-----------------
 1 file changed, 5 insertions(+), 17 deletions(-)

diff --git a/hft.py b/hft.py
index 4e0ae1d..3e30477 100644
--- a/hft.py
+++ b/hft.py
@@ -102,31 +102,19 @@ def tokenize_function(examples):
         return_tensors="pt"
     )
     tokenized["labels"] = tokenized["input_ids"].clone()
-    tokenized["source_idx"] = torch.tensor(examples["source_idx"], dtype=torch.long)  # Make sure this is a tensor
+    tokenized["source_idx"] = examples["source_idx"]
     return tokenized
 
 def custom_collate_fn(batch):
     input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
     attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
     labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
-    source_idx = torch.tensor([b["source_idx"] for b in batch], dtype=torch.long)
+
+    # Add a default source_idx if it is missing
+    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
 
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
 
-def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-    outputs = super().forward(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        labels=labels,
-        **kwargs
-    )
-
-    if source_idx is not None and source_idx.dim() == 1:
-        source_embeds = self.source_embedding(source_idx).unsqueeze(1)  # (batch_size, 1, hidden_size)
-        outputs.logits += source_embeds  # Make sure the dimensions match
-
-    return outputs
-
 class CustomModel(AutoModelForCausalLM):
     def __init__(self, config):
         super().__init__(config)
@@ -193,7 +181,7 @@ trainer = CustomTrainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_dataset,
-    data_collator=custom_collate_fn  # Add the custom collate function
+    data_collator=custom_collate_fn  # Use the custom collate_fn
 )
 
 trainer.train()
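
Reviewer note (not part of the patch): a minimal, self-contained sketch of what the patched custom_collate_fn does when some examples lack source_idx. The demo_batch values below are made up purely for illustration.

import torch

def custom_collate_fn(batch):
    # Same logic as the patched function in hft.py
    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
    labels = torch.stack([torch.tensor(b["labels"]) for b in batch])

    # A missing source_idx now falls back to -1 instead of raising KeyError
    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)

    return {"input_ids": input_ids, "attention_mask": attention_mask,
            "labels": labels, "source_idx": source_idx}

# Hypothetical two-example batch; the second example carries no source_idx.
demo_batch = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3], "source_idx": 0},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 1], "labels": [4, 5, 6]},
]
print(custom_collate_fn(demo_batch)["source_idx"])  # tensor([ 0, -1])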