diff --git a/hft.py b/hft.py index e2d40f5..28d0795 100644 --- a/hft.py +++ b/hft.py @@ -1,7 +1,7 @@ import os import torch import torch.nn as nn -from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel from datasets import Dataset from PIL import Image import re @@ -117,7 +117,7 @@ def custom_collate_fn(batch): class CustomModel(nn.Module): def __init__(self, model_name, config): - super().__init__() + super().__init__(config) self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config) self.source_embedding = nn.Embedding( num_embeddings=1000,