From 44b43368226f2c7859b5a5e4f962861cb64a5494 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 17:16:14 +0100
Subject: [PATCH] mod

---
 hft.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/hft.py b/hft.py
index 05014cd..d58f531 100644
--- a/hft.py
+++ b/hft.py
@@ -124,16 +124,12 @@ class CustomModel(nn.Module):
         )
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        outputs = self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
-        )
         if source_idx is not None:
-            print("outputs.logits shape:", outputs.logits.shape)
-            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, outputs.logits.size(1), -1)
-            print("source_embeds shape:", source_embeds.shape)
-            outputs.logits += source_embeds
+            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
+            # Add the source embeddings to the model's input embeddings
+            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
+            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
+        else:
+            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
         return outputs
 
@@ -161,6 +157,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
+print("Vocabulary size:", config.vocab_size)
 model = CustomModel(model_name, config)
 model.to("cpu")
 
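
For context, the hunk above only touches forward(), so below is a minimal, self-contained sketch of the approach the patch switches to: a learned per-source vector is added to the token embeddings and the sum is passed to the base model through inputs_embeds, instead of being added to the output logits. The constructor, NUM_SOURCES, and the "gpt2" checkpoint are assumptions made for illustration; the real hft.py may define these differently.

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"   # assumed checkpoint
NUM_SOURCES = 4       # assumed number of data sources

class CustomModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
        # One trainable vector per source, matching the token-embedding width.
        self.source_embedding = nn.Embedding(NUM_SOURCES, config.hidden_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                source_idx=None, **kwargs):
        if source_idx is not None:
            # (batch, hidden) -> (batch, seq_len, hidden), then add to the token embeddings.
            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(
                -1, input_ids.size(1), -1)
            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            return self.base_model(inputs_embeds=hidden_states,
                                   attention_mask=attention_mask, labels=labels, **kwargs)
        return self.base_model(input_ids=input_ids, attention_mask=attention_mask,
                               labels=labels, **kwargs)

# Usage: tag each example in the batch with a source index.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
config = AutoModelForCausalLM.from_pretrained(model_name).config
print("Vocabulary size:", config.vocab_size)
model = CustomModel(model_name, config)
batch = tokenizer(["first source text", "another source text"],
                  return_tensors="pt", padding=True)
source_idx = torch.tensor([0, 2])  # one source id per example
out = model(**batch, labels=batch["input_ids"], source_idx=source_idx)
print(out.loss.item(), out.logits.shape)

Moving the source signal from the logits to the input embeddings lets every layer of the base model condition on the source, which is presumably the motivation for the change.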