diff --git a/hft.py b/hft.py index 72ad4c0..a947664 100644 --- a/hft.py +++ b/hft.py @@ -192,7 +192,7 @@ trainer.train() # Funkcja generująca odpowiedź def generate_answer(question, max_length=200): model.eval() - inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512).to(device) + inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512) with torch.no_grad(): outputs = model.generate(