mod
This commit is contained in:
parent
a1978e7683
commit
8f8843fbb2
4
hft.py
4
hft.py
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import re
|
import re
|
||||||
|
|
@ -117,7 +117,7 @@ def custom_collate_fn(batch):
|
||||||
|
|
||||||
class CustomModel(nn.Module):
|
class CustomModel(nn.Module):
|
||||||
def __init__(self, model_name, config):
|
def __init__(self, model_name, config):
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
|
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
|
||||||
self.source_embedding = nn.Embedding(
|
self.source_embedding = nn.Embedding(
|
||||||
num_embeddings=1000,
|
num_embeddings=1000,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue