From 8f8843fbb27a31a118d50e421508a5199bb0a17e Mon Sep 17 00:00:00 2001 From: "l.gabrysiak" Date: Tue, 25 Feb 2025 17:47:35 +0100 Subject: [PATCH] mod --- hft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hft.py b/hft.py index e2d40f5..28d0795 100644 --- a/hft.py +++ b/hft.py @@ -1,7 +1,7 @@ import os import torch import torch.nn as nn -from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel from datasets import Dataset from PIL import Image import re @@ -117,7 +117,7 @@ def custom_collate_fn(batch): class CustomModel(nn.Module): def __init__(self, model_name, config): - super().__init__() + super().__init__(config) self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config) self.source_embedding = nn.Embedding( num_embeddings=1000,