This commit is contained in:
l.gabrysiak 2025-02-25 17:06:58 +01:00
parent 59c1b99f99
commit 7c24c381e0
1 changed files with 6 additions and 6 deletions

12
hft.py
View File

@ -114,9 +114,10 @@ def custom_collate_fn(batch):
print("source_idx shape:", source_idx.shape) # Debugowanie
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
class CustomModel(AutoModelForCausalLM):
def __init__(self, config):
super().__init__(config)
class CustomModel(nn.Module):
def __init__(self, model_name, config):
super().__init__()
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
self.source_embedding = nn.Embedding(
num_embeddings=1000, # Maksymalna liczba unikalnych źródeł
embedding_dim=config.hidden_size,
@ -124,7 +125,7 @@ class CustomModel(AutoModelForCausalLM):
)
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
outputs = super().forward(
outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
@ -160,8 +161,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
# Inicjalizacja modelu
config = AutoModelForCausalLM.from_pretrained(model_name).config
model = CustomModel.from_pretrained(model_name, config=config)
model.to("cpu")
model = CustomModel(model_name, config)
# Konfiguracja treningu
training_args = TrainingArguments(