l.gabrysiak 2025-02-25 17:16:14 +01:00
parent ce550ad79d
commit 44b4336822
1 changed file with 7 additions and 11 deletions

hft.py

@@ -124,18 +124,13 @@ class CustomModel(nn.Module):
     )
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        outputs = self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
-        )
         if source_idx is not None:
-            print("outputs.logits shape:", outputs.logits.shape)
-            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, outputs.logits.size(1), -1)
-            print("source_embeds shape:", source_embeds.shape)
-            outputs.logits += source_embeds
+            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
+            # Add source embeddings to the model input
+            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
+            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
+        else:
+            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
         return outputs
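The change is substantive, not cosmetic: the old code added the learned source embeddings directly onto outputs.logits, whose last dimension is the vocabulary size rather than the hidden size, which is most likely what the debug prints were chasing. The new code instead injects the source signal at the input-embedding level and passes it through inputs_embeds, so every layer of the base model is conditioned on the source. For context, a minimal sketch of how the full CustomModel could look after this commit; __init__ is not part of the diff, so num_sources and the nn.Embedding initialization are assumptions:

import torch.nn as nn
from transformers import AutoModelForCausalLM

class CustomModel(nn.Module):
    # Sketch reconstructed from the diff; num_sources and the embedding
    # setup are assumptions, since the constructor is not shown here.
    def __init__(self, model_name, config, num_sources=2):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
        # One learned vector per data source, sized to the model's hidden
        # dimension so it can be added to token embeddings.
        self.source_embedding = nn.Embedding(num_sources, config.hidden_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        if source_idx is not None:
            # (batch, hidden) -> (batch, 1, hidden) -> (batch, seq_len, hidden)
            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
            # Add source embeddings to the token embeddings and feed the
            # result through inputs_embeds instead of input_ids.
            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
        else:
            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
        return outputs

Note that labels still flow through the base model unchanged, so the standard causal-LM loss is computed on the source-conditioned hidden states.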
@@ -161,6 +156,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
+print("Vocabulary size:", config.vocab_size)
 model = CustomModel(model_name, config)
 model.to("cpu")
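A quick smoke test of the patched forward pass, assuming model_name and model are set up as in the script above; the example text and source id 0 are purely illustrative:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
batch = tokenizer(["example input text"], return_tensors="pt")
source_idx = torch.tensor([0])  # hypothetical source id for this batch
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["input_ids"],  # causal-LM loss against the inputs themselves
    source_idx=source_idx,
)
print("loss:", outputs.loss.item(), "logits shape:", outputs.logits.shape)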