From 37e14536cb388f52aa4823ba7f27834b772becd9 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 18:23:00 +0100
Subject: [PATCH] mod

---
Review notes (text between the three-dash line and the diff is ignored by
`git am`):
- custom_collate_fn no longer reads the global `model` to choose a device;
  it now builds the batch tensors on the CPU.
- CustomTrainer.compute_loss moves every entry of `inputs` to the device of
  the model's first parameter before popping labels and calling the model.
- TrainingArguments has no `pin_memory` field (passing it raises TypeError);
  the option is spelled `dataloader_pin_memory`, which is what the last hunk
  adds. Because that added line was corrected by hand, the post-image blob
  id on the `index` line below no longer matches the resulting file.

 hft.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/hft.py b/hft.py
index a26323d..7f0d415 100644
--- a/hft.py
+++ b/hft.py
@@ -107,11 +107,10 @@ def tokenize_function(examples):
     return tokenized

 def custom_collate_fn(batch):
-    device = next(model.parameters()).device
-    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch]).to(device)
-    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch]).to(device)
-    labels = torch.stack([torch.tensor(b["labels"]) for b in batch]).to(device)
-    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long).to(device)
+    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch]).cpu()
+    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch]).cpu()
+    labels = torch.stack([torch.tensor(b["labels"]) for b in batch]).cpu()
+    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long).cpu()
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}

 class CustomModel(nn.Module):
@@ -137,6 +136,8 @@ class CustomModel(nn.Module):

 class CustomTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
         labels = inputs.pop("labels")
         source_idx = inputs.pop("source_idx", None)
         outputs = model(**inputs, labels=labels, source_idx=source_idx)
@@ -196,6 +197,7 @@ training_args = TrainingArguments(
     save_steps=1000,
     logging_strategy="no",
     report_to="none",
+    dataloader_pin_memory=True,
 )

 # Trening