This commit is contained in:
l.gabrysiak 2025-02-25 17:47:35 +01:00
parent a1978e7683
commit 8f8843fbb2
1 changed files with 2 additions and 2 deletions

4
hft.py
View File

@ -1,7 +1,7 @@
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel
from datasets import Dataset
from PIL import Image
import re
@ -117,7 +117,7 @@ def custom_collate_fn(batch):
class CustomModel(nn.Module):
def __init__(self, model_name, config):
super().__init__()
super().__init__(config)
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
self.source_embedding = nn.Embedding(
num_embeddings=1000,