diff --git a/hft.py b/hft.py index 6e921fd..07d7e83 100644 --- a/hft.py +++ b/hft.py @@ -122,6 +122,7 @@ class CustomModel(nn.Module): embedding_dim=config.hidden_size, padding_idx=-1 ) + self.device = next(self.base_model.parameters()).device def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs): if source_idx is not None: @@ -145,8 +146,9 @@ class CustomTrainer(Trainer): return (loss, outputs) if return_outputs else loss def generate_answer(question, model, tokenizer, source_mapper, max_length=200): + device = next(model.parameters()).device inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512) - inputs = {k: v.to(model.device) for k, v in inputs.items()} + inputs = {k: v.to(device) for k, v in inputs.items()} outputs = model.base_model.generate( **inputs,