diff --git a/gpt.py b/gpt.py index 99c2dda..213c655 100644 --- a/gpt.py +++ b/gpt.py @@ -39,11 +39,8 @@ def main(): tokenized_dataset = dataset.map(tokenize_function, batched=True) # Model i data collator - model = AutoModelForCausalLM.from_pretrained( - MODEL_NAME, - mean_resizing=False - ) - model.resize_token_embeddings(len(tokenizer)) + model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) + model.resize_token_embeddings(len(tokenizer), mean_resizing=False) data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer,