From 5a2aef665ceb696f7401aabeffd6db3168ed7ddc Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 16:53:09 +0100
Subject: [PATCH] mod

---
 hft.py | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/hft.py b/hft.py
index d69af70..d81cc03 100644
--- a/hft.py
+++ b/hft.py
@@ -109,26 +109,22 @@ def custom_collate_fn(batch):
     input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
     attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
     labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
+
+    # Add a default source_idx of -1 when a record does not provide one
     source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "labels": labels,
-        "source_idx": source_idx
-    }
+    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
 
 
 class CustomModel(AutoModelForCausalLM):
     def __init__(self, config):
         super().__init__(config)
         self.source_embedding = nn.Embedding(
-            num_embeddings=1000,
+            num_embeddings=1000,  # maximum number of unique sources
             embedding_dim=config.hidden_size,
             padding_idx=-1
         )
 
-    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
-        source_idx = kwargs.pop('source_idx', None)  # fetch and drop source_idx from kwargs
+    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         outputs = super().forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -137,7 +133,7 @@ class CustomModel(AutoModelForCausalLM):
         )
 
         if source_idx is not None:
-            source_idx = source_idx.to(outputs.logits.device)  # Ensure same device
+            # Add the source embedding to the logits
             source_embeds = self.source_embedding(source_idx).unsqueeze(1)
             outputs.logits += source_embeds
 
@@ -154,8 +150,7 @@ class CustomTrainer(Trainer):
 source_mapper = SourceMapper()
 model_name = "crumb/nano-mistral" #"google/gemma-2-2b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-if tokenizer.pad_token is None:
-    tokenizer.pad_token = tokenizer.eos_token
+tokenizer.pad_token = tokenizer.eos_token
 
 # Data preparation
 catalog_path = "file_catalog.json"
@@ -166,7 +161,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
 model = CustomModel.from_pretrained(model_name, config=config)
-model.to("cuda" if torch.cuda.is_available() else "cpu")
+model.to("cpu")
 
 # Training configuration
 training_args = TrainingArguments(
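
Below, a minimal sketch (not part of the patch) of the collate contract this change relies on: records that lack "source_idx" are mapped to -1, matching the padding_idx used by source_embedding. The toy records are hypothetical, and custom_collate_fn from hft.py is assumed to be in scope:

import torch
from torch.utils.data import DataLoader

# Hypothetical two-record dataset; only the second record carries an explicit source_idx.
toy_records = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 0], "labels": [4, 5, 6], "source_idx": 7},
]

loader = DataLoader(toy_records, batch_size=2, collate_fn=custom_collate_fn)
batch = next(iter(loader))
print(batch["source_idx"])  # tensor([-1,  7]) -- the missing source_idx fell back to -1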