mod
This commit is contained in:
parent
745446f6fd
commit
0ace5d1348
3
hft.py
3
hft.py
|
|
@@ -10,6 +10,7 @@ import pytesseract
|
||||||
import docx2txt
|
import docx2txt
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
import json
|
import json
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
|
|
||||||
|
|
@@ -165,6 +166,8 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
|
||||||
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
||||||
#model = CustomModel.from_pretrained(model_name, config=config)
|
#model = CustomModel.from_pretrained(model_name, config=config)
|
||||||
model = CustomModel.from_pretrained(model_name)
|
model = CustomModel.from_pretrained(model_name)
|
||||||
|
model.config.gradient_checkpointing = True
|
||||||
|
model.config.use_cache = False
|
||||||
model.resize_token_embeddings(len(tokenizer))
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue