From 7c24c381e0c8f5a66df5c78bcc5f51b4379abcbd Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 17:06:58 +0100
Subject: [PATCH] mod

---
 hft.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/hft.py b/hft.py
index 32e9aa2..11f15b8 100644
--- a/hft.py
+++ b/hft.py
@@ -114,9 +114,10 @@ def custom_collate_fn(batch):
     print("source_idx shape:", source_idx.shape)  # Debugging
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
 
-class CustomModel(AutoModelForCausalLM):
-    def __init__(self, config):
-        super().__init__(config)
+class CustomModel(nn.Module):
+    def __init__(self, model_name, config):
+        super().__init__()
+        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,  # Maximum number of unique sources
             embedding_dim=config.hidden_size,
@@ -124,7 +125,7 @@ class CustomModel(AutoModelForCausalLM):
         )
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
-        outputs = super().forward(
+        outputs = self.base_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             labels=labels,
@@ -160,8 +161,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
-model = CustomModel.from_pretrained(model_name, config=config)
-model.to("cpu")
+model = CustomModel(model_name, config)
 
 # Training configuration
 training_args = TrainingArguments(
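
Note on the change: AutoModelForCausalLM is a factory class whose __init__ raises an error when called directly, so the old code (subclassing it and calling super().__init__(config)) could not construct at all, and CustomModel.from_pretrained would have returned the underlying concrete model class rather than CustomModel. The patch switches to composition: the pretrained LM is loaded once and held as a submodule of a plain nn.Module. Below is a minimal, self-contained sketch of that pattern. The checkpoint name "sshleifer/tiny-gpt2" and the use of AutoConfig are illustrative assumptions, not taken from hft.py; likewise, the hunks above cut off before showing where source_embedding is injected, so the forward body here is a plain delegation only.

    import torch.nn as nn
    from transformers import AutoConfig, AutoModelForCausalLM

    class CustomModel(nn.Module):
        def __init__(self, model_name, config):
            super().__init__()
            # Compose rather than inherit: keep the pretrained LM as a submodule.
            self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
            self.source_embedding = nn.Embedding(
                num_embeddings=1000,  # maximum number of unique sources
                embedding_dim=config.hidden_size,
            )

        def forward(self, input_ids=None, attention_mask=None, labels=None,
                    source_idx=None, **kwargs):
            # Delegate to the wrapped model. super().forward() no longer applies,
            # since nn.Module defines no forward of its own.
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                **kwargs,
            )

    model_name = "sshleifer/tiny-gpt2"  # hypothetical small checkpoint, not from the patch
    config = AutoConfig.from_pretrained(model_name)
    model = CustomModel(model_name, config)

One further design note: the patch obtains the config via AutoModelForCausalLM.from_pretrained(model_name).config, which instantiates the full model just to read its config; AutoConfig.from_pretrained, as in the sketch, fetches only the config file. Dropping model.to("cpu") is harmless, as a freshly constructed nn.Module lives on the CPU by default.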