mod
This commit is contained in:
parent
59c1b99f99
commit
7c24c381e0
12
hft.py
12
hft.py
|
|
@ -114,9 +114,10 @@ def custom_collate_fn(batch):
|
|||
print("source_idx shape:", source_idx.shape) # Debugowanie
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
|
||||
|
||||
class CustomModel(AutoModelForCausalLM):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
class CustomModel(nn.Module):
|
||||
def __init__(self, model_name, config):
|
||||
super().__init__()
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
|
||||
self.source_embedding = nn.Embedding(
|
||||
num_embeddings=1000, # Maksymalna liczba unikalnych źródeł
|
||||
embedding_dim=config.hidden_size,
|
||||
|
|
@ -124,7 +125,7 @@ class CustomModel(AutoModelForCausalLM):
|
|||
)
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
|
||||
outputs = super().forward(
|
||||
outputs = self.base_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
|
|
@ -160,8 +161,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
|
|||
|
||||
# Inicjalizacja modelu
|
||||
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
||||
model = CustomModel.from_pretrained(model_name, config=config)
|
||||
model.to("cpu")
|
||||
model = CustomModel(model_name, config)
|
||||
|
||||
# Konfiguracja treningu
|
||||
training_args = TrainingArguments(
|
||||
|
|
|
|||
Loading…
Reference in New Issue