From 9e3ff6db7ff8062ec943fc882fcc522e2c7a86b6 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 19:06:27 +0100
Subject: [PATCH] mod

---
 hft.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/hft.py b/hft.py
index 6e921fd..07d7e83 100644
--- a/hft.py
+++ b/hft.py
@@ -122,6 +122,7 @@ class CustomModel(nn.Module):
             embedding_dim=config.hidden_size,
             padding_idx=-1
         )
+        self.device = next(self.base_model.parameters()).device
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         if source_idx is not None:
@@ -145,8 +146,9 @@ class CustomTrainer(Trainer):
         return (loss, outputs) if return_outputs else loss
 
 def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
+    device = next(model.parameters()).device
     inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    inputs = {k: v.to(device) for k, v in inputs.items()}
     outputs = model.base_model.generate(
         **inputs,
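
For context, a minimal standalone sketch of the device-handling pattern this patch applies: inferring the device from the model's parameters instead of reading a `model.device` attribute, which a plain nn.Module subclass such as CustomModel does not define. `TinyWrapper` below is a hypothetical stand-in for illustration, not code from hft.py.

import torch
import torch.nn as nn

class TinyWrapper(nn.Module):
    # Hypothetical stand-in for CustomModel: wraps a base model and,
    # like any plain nn.Module, exposes no .device attribute of its own.
    def __init__(self):
        super().__init__()
        self.base_model = nn.Linear(4, 4)

model = TinyWrapper()
# Works for any nn.Module with at least one parameter, regardless of wrapper class.
device = next(model.parameters()).device
batch = {"input_ids": torch.ones(1, 4)}
batch = {k: v.to(device) for k, v in batch.items()}  # move inputs to the model's device
print(device)  # e.g. cpu or cuda:0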