This commit is contained in:
l.gabrysiak 2025-02-25 17:16:14 +01:00
parent ce550ad79d
commit 44b4336822
1 changed files with 7 additions and 11 deletions

18
hft.py
View File

@ -124,18 +124,13 @@ class CustomModel(nn.Module):
)
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
**kwargs
)
if source_idx is not None:
print("outputs.logits shape:", outputs.logits.shape)
source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, outputs.logits.size(1), -1)
print("source_embeds shape:", source_embeds.shape)
outputs.logits += source_embeds
source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
# Dodaj embeddingi źródła do wejścia modelu
hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
else:
outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
return outputs
@ -161,6 +156,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
# Inicjalizacja modelu
config = AutoModelForCausalLM.from_pretrained(model_name).config
print("Vocabulary size:", config.vocab_size)
model = CustomModel(model_name, config)
model.to("cpu")