From d00a1831047605569f1a74b9c9f39238bbc94398 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 23:25:21 +0100
Subject: [PATCH] mod

---
 hft.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/hft.py b/hft.py
index 5a2e8f7..13ee571 100644
--- a/hft.py
+++ b/hft.py
@@ -156,13 +156,13 @@ class LegalAITrainer:
 
         features = Features({
             "text": Value("string"),
-            "source_idx": Value("int32"),
+            "source_idx": Sequence(Value("int32")),
             "is_legal": Value("int32")
         })
 
         return Dataset.from_dict({
             "text": [d["text"] for d in data],
-            "source_idx": [d["source_idx"] for d in data],
+            "source_idx": [[d["source_idx"]] for d in data],  # returned as a list of lists
             "is_legal": [d["is_legal"] for d in data]
         }, features=features), source_mapper
 
@@ -179,11 +179,13 @@ class LegalAITrainer:
                 max_length=512,
                 return_tensors="pt"
             )
+
+            # Convert the tensors back to plain lists of the appropriate types
             return {
-                "input_ids": tokenized["input_ids"].squeeze().to(torch.int32),
-                "attention_mask": tokenized["attention_mask"].squeeze().to(torch.int32),
-                "labels": tokenized["input_ids"].squeeze().clone().to(torch.int32),
-                "source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
+                "input_ids": [ids.tolist() for ids in tokenized["input_ids"]],
+                "attention_mask": [mask.tolist() for mask in tokenized["attention_mask"]],
+                "labels": [labels.tolist() for labels in tokenized["input_ids"]],
+                "source_idx": examples["source_idx"]  # already a length-1 sequence per example
             }
 
         tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
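
For review context, a minimal standalone sketch of the pattern this change moves to: with batched=True, Dataset.map expects the mapping function to return plain Python lists (one entry per example), and a Sequence(Value("int32")) feature stores each value as a list, which is why from_dict now wraps each scalar and tokenize_fn can pass the column straight through. The sample records and the "gpt2" checkpoint below are placeholders, not what hft.py actually loads.

from datasets import Dataset, Features, Sequence, Value
from transformers import AutoTokenizer

# Placeholder records; hft.py builds these from its own corpora.
data = [
    {"text": "A statutory provision about tort liability.", "source_idx": 0, "is_legal": 1},
    {"text": "An ordinary, non-legal sentence.", "source_idx": 1, "is_legal": 0},
]

features = Features({
    "text": Value("string"),
    "source_idx": Sequence(Value("int32")),  # each value is a length-1 list
    "is_legal": Value("int32"),
})

dataset = Dataset.from_dict({
    "text": [d["text"] for d in data],
    "source_idx": [[d["source_idx"]] for d in data],  # wrap each scalar in a list
    "is_legal": [d["is_legal"] for d in data],
}, features=features)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

def tokenize_fn(examples):
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )
    # Convert each row's tensors back to plain lists so Arrow can store them;
    # source_idx passes through because it is already a list per example.
    return {
        "input_ids": [ids.tolist() for ids in tokenized["input_ids"]],
        "attention_mask": [mask.tolist() for mask in tokenized["attention_mask"]],
        "labels": [ids.tolist() for ids in tokenized["input_ids"]],
        "source_idx": examples["source_idx"],
    }

tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
print(tokenized_dataset)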