diff --git a/hft.py b/hft.py index 3ef5864..1110e83 100644 --- a/hft.py +++ b/hft.py @@ -1,7 +1,7 @@ import os import torch import torch.nn as nn -from transformers import AutoTokenizer, GPTNeoForCausalLM # Poprawiono importy +from transformers import AutoTokenizer, GPTNeoForCausalLM, Trainer # Poprawiono importy from datasets import Dataset from PIL import Image import re