This commit is contained in:
l.gabrysiak 2025-02-25 19:06:27 +01:00
parent 8f422c4b7c
commit 9e3ff6db7f
1 changed file with 3 additions and 1 deletion

4
hft.py
View File

@@ -122,6 +122,7 @@ class CustomModel(nn.Module):
embedding_dim=config.hidden_size, embedding_dim=config.hidden_size,
padding_idx=-1 padding_idx=-1
) )
self.device = next(self.base_model.parameters()).device
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs): def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
if source_idx is not None: if source_idx is not None:
@@ -145,8 +146,9 @@ class CustomTrainer(Trainer):
return (loss, outputs) if return_outputs else loss return (loss, outputs) if return_outputs else loss
def generate_answer(question, model, tokenizer, source_mapper, max_length=200): def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
device = next(model.parameters()).device
inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512) inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(model.device) for k, v in inputs.items()} inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model.base_model.generate( outputs = model.base_model.generate(
**inputs, **inputs,