trener mod
This commit is contained in:
parent
a380f06555
commit
329d76d072
4
hft.py
4
hft.py
|
|
@ -13,6 +13,7 @@ import json
|
|||
from huggingface_hub import login
|
||||
|
||||
login(f"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")
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
def load_file_catalog(catalog_path):
|
||||
with open(catalog_path, 'r', encoding='utf-8') as file:
|
||||
|
|
@ -101,8 +102,7 @@ class CustomTrainer(Trainer):
|
|||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop("labels")
|
||||
source = inputs.pop("source")
|
||||
source_ids = torch.tensor([hash(s) % 1000 for s in source], device=model.device)
|
||||
outputs = model(**inputs, labels=labels, source=source_ids)
|
||||
outputs = model(**inputs, labels=labels)
|
||||
loss = outputs.loss
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue