This commit is contained in:
l.gabrysiak 2025-02-25 17:47:35 +01:00
parent a1978e7683
commit 8f8843fbb2
1 changed files with 2 additions and 2 deletions

4
hft.py
View File

@ -1,7 +1,7 @@
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel
from datasets import Dataset from datasets import Dataset
from PIL import Image from PIL import Image
import re import re
@ -117,7 +117,7 @@ def custom_collate_fn(batch):
class CustomModel(nn.Module): class CustomModel(nn.Module):
def __init__(self, model_name, config): def __init__(self, model_name, config):
super().__init__() super().__init__(config)
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config) self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
self.source_embedding = nn.Embedding( self.source_embedding = nn.Embedding(
num_embeddings=1000, num_embeddings=1000,