From 2cceeb31c8dfebb1afe93129acfc53bcd2ae14af Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 20:01:50 +0100
Subject: [PATCH] Change CustomModel

---
 hft.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/hft.py b/hft.py
index acf1678..a0c3288 100644
--- a/hft.py
+++ b/hft.py
@@ -115,10 +115,11 @@ def custom_collate_fn(batch):
     #print("source_idx shape:", source_idx.shape)  # Debugging
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
 
-class CustomModel(nn.Module):
+# Modified CustomModel class
+class CustomModel(AutoModelForCausalLM):  # 🔵 Changed inheritance
     def __init__(self, model_name, config):
-        super().__init__()
-        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+        super().__init__(config)  # 🔵 Initialize the base class
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,
             embedding_dim=config.hidden_size,
@@ -127,16 +128,15 @@ class CustomModel(nn.Module):
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         if source_idx is not None:
-            #print("Max source_idx:", torch.max(source_idx))
-            #print("Num embeddings:", self.source_embedding.num_embeddings)
             source_idx = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings - 1)
             source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1)
-            hidden_states = self.base_model.get_input_embeddings()(input_ids) + source_embeds
-            outputs = self.base_model(inputs_embeds=hidden_states, attention_mask=attention_mask, labels=labels, **kwargs)
-        else:
-            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
-
-        return outputs
+            inputs_embeds = self.model.get_input_embeddings()(input_ids) + source_embeds
+            return self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
+
+    # 🔵 Add generate method
+    def generate(self, *args, **kwargs):
+        return self.model.generate(*args, **kwargs)
 
 class CustomTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
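
A note on the inheritance change above: in transformers, AutoModelForCausalLM is a
factory class, and its __init__ raises an error (it is designed to be instantiated
only via from_pretrained or from_config), so super().__init__(config) in the new
CustomModel will fail at runtime. The source-embedding idea and the delegating
generate() still work if the wrapper keeps a concrete base class. Below is a minimal
sketch of one such alternative, subclassing PreTrainedModel instead; the class name
SourceConditionedLM and the "gpt2" checkpoint in the usage comment are illustrative
assumptions, not part of this patch.

    import torch
    import torch.nn as nn
    from transformers import AutoModelForCausalLM, PreTrainedModel

    class SourceConditionedLM(PreTrainedModel):
        """Wraps a causal LM and adds a learned per-source embedding."""

        def __init__(self, config, model_name):
            # PreTrainedModel, unlike AutoModelForCausalLM, accepts a config directly.
            super().__init__(config)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
            self.source_embedding = nn.Embedding(1000, config.hidden_size)

        def forward(self, input_ids=None, attention_mask=None, labels=None,
                    source_idx=None, **kwargs):
            if source_idx is not None:
                # Clamp indices into range, then add one source vector to every token.
                source_idx = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings - 1)
                source_embeds = (self.source_embedding(source_idx)
                                 .unsqueeze(1).expand(-1, input_ids.size(1), -1))
                inputs_embeds = self.model.get_input_embeddings()(input_ids) + source_embeds
                return self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                  labels=labels, **kwargs)
            return self.model(input_ids=input_ids, attention_mask=attention_mask,
                              labels=labels, **kwargs)

        # Delegate generation to the wrapped model, as the patch does.
        def generate(self, *args, **kwargs):
            return self.model.generate(*args, **kwargs)

    # Usage sketch (checkpoint name is illustrative):
    #   from transformers import AutoConfig
    #   config = AutoConfig.from_pretrained("gpt2")
    #   model = SourceConditionedLM(config, "gpt2")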