mod

parent ce550ad79d
commit 44b4336822

hft.py (18 changes)
@@ -124,18 +124,13 @@ class CustomModel(nn.Module):
         )
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        outputs = self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
-        )
-
         if source_idx is not None:
-            print("outputs.logits shape:", outputs.logits.shape)
-            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, outputs.logits.size(1), -1)
-            print("source_embeds shape:", source_embeds.shape)
-            outputs.logits += source_embeds
+            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
+            # Add the source embeddings to the model input
+            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
+            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
+        else:
+            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
 
         return outputs
 
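For context, the pattern this hunk lands on is: look up a learned per-source vector, broadcast it across the sequence, and add it to the token embeddings before the base model runs. A minimal sketch of the surrounding class follows; the diff does not show __init__, so the constructor here (including the num_sources parameter) is an assumption, not the commit's actual code.

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

class CustomModel(nn.Module):
    # Hypothetical constructor: the real __init__ is outside this hunk.
    def __init__(self, model_name, config, num_sources=2):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
        # One learned vector per data source, matching the token-embedding width
        self.source_embedding = nn.Embedding(num_sources, config.hidden_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        if source_idx is not None:
            # (batch,) -> (batch, 1, hidden) -> (batch, seq_len, hidden)
            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
            # Inject the source signal at the input-embedding layer
            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
        else:
            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
        return outputs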
@@ -161,6 +156,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
+print("Vocabulary size:", config.vocab_size)
 model = CustomModel(model_name, config)
 model.to("cpu")
 
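Why the forward pass moved: the removed code added source_embeds, whose last dimension is the hidden size, onto outputs.logits, whose last dimension is the vocabulary size; the addition only broadcasts when the two sizes happen to match, which is presumably what the shape prints (and the new vocabulary-size print) were chasing. Injecting the source vector at the input-embedding layer sidesteps the mismatch. A hypothetical smoke test for the new path, assuming a tokenizer and the model_name variable from earlier in hft.py:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)  # assumption: same checkpoint as the model
batch = tokenizer(["example text"], return_tensors="pt")
source_idx = torch.tensor([0])  # ID of the data source this example came from
out = model(input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["input_ids"],
            source_idx=source_idx)
print(out.loss, out.logits.shape)  # logits: (1, seq_len, vocab_size)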