mod
This commit is contained in:
parent
8f422c4b7c
commit
9e3ff6db7f
4
hft.py
4
hft.py
|
|
@ -122,6 +122,7 @@ class CustomModel(nn.Module):
|
||||||
embedding_dim=config.hidden_size,
|
embedding_dim=config.hidden_size,
|
||||||
padding_idx=-1
|
padding_idx=-1
|
||||||
)
|
)
|
||||||
|
self.device = next(self.base_model.parameters()).device
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
|
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
|
||||||
if source_idx is not None:
|
if source_idx is not None:
|
||||||
|
|
@ -145,8 +146,9 @@ class CustomTrainer(Trainer):
|
||||||
return (loss, outputs) if return_outputs else loss
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
|
def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
|
||||||
|
device = next(model.parameters()).device
|
||||||
inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
|
inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
|
||||||
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||||
|
|
||||||
outputs = model.base_model.generate(
|
outputs = model.base_model.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue