commit 44b4336822 (parent ce550ad79d)

    mod

hft.py | 18 +++++++-----------
1 file changed, 7 insertions(+), 11 deletions(-)
@@ -124,18 +124,13 @@ class CustomModel(nn.Module):
         )
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        outputs = self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
-        )
-
         if source_idx is not None:
-            print("outputs.logits shape:", outputs.logits.shape)
-            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, outputs.logits.size(1), -1)
-            print("source_embeds shape:", source_embeds.shape)
-            outputs.logits += source_embeds
+            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
+            # Add the source embeddings to the model input
+            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
+            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
+        else:
+            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
         return outputs
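For context, a minimal runnable sketch of CustomModel after this change. The __init__ shown here is an assumption (the real one, including source_embedding, sits above this hunk), and num_sources is introduced purely for illustration. The rewritten forward adds the source vector in input-embedding space, where its size matches the token embeddings, rather than to outputs.logits, whose last dimension is the vocabulary size; that mismatch is presumably what the removed shape prints were chasing.

import torch.nn as nn
from transformers import AutoModelForCausalLM

class CustomModel(nn.Module):
    # num_sources is a hypothetical parameter; the real __init__ is outside this hunk.
    def __init__(self, model_name, config, num_sources=2):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
        # Assumption: one learned vector per data source, sized to the hidden
        # dimension so it can be summed with the token embeddings.
        self.source_embedding = nn.Embedding(num_sources, config.hidden_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        if source_idx is not None:
            # (batch, hidden) -> (batch, seq_len, hidden): same source vector at every position
            source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
            # Add the source embeddings to the token embeddings, then bypass input_ids
            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
        else:
            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
        return outputs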
@@ -161,6 +156,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
+print("Vocabulary size:", config.vocab_size)
 model = CustomModel(model_name, config)
 model.to("cpu")
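A hedged usage sketch of the initialized model; the checkpoint name and sample batch are assumptions for illustration only (the real model_name, tokenizer, and dataset are defined elsewhere in hft.py).

import torch
from transformers import AutoTokenizer

model_name = "gpt2"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)

batch = tokenizer(["an example sentence"], return_tensors="pt")
source_idx = torch.tensor([0])  # one integer source ID per sequence

outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["input_ids"],
    source_idx=source_idx,
)
print("loss:", outputs.loss.item())
print("logits shape:", tuple(outputs.logits.shape))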