This commit is contained in:
l.gabrysiak 2025-02-25 23:25:21 +01:00
parent 4342eb69c4
commit d00a183104
1 changed file with 8 additions and 6 deletions

14
hft.py
View File

@@ -156,13 +156,13 @@ class LegalAITrainer:
features = Features({ features = Features({
"text": Value("string"), "text": Value("string"),
"source_idx": Value("int32"), "source_idx": Sequence(Value("int32")),
"is_legal": Value("int32") "is_legal": Value("int32")
}) })
return Dataset.from_dict({ return Dataset.from_dict({
"text": [d["text"] for d in data], "text": [d["text"] for d in data],
"source_idx": [d["source_idx"] for d in data], "source_idx": [[d["source_idx"]] for d in data], # Zwracamy jako listę list
"is_legal": [d["is_legal"] for d in data] "is_legal": [d["is_legal"] for d in data]
}, features=features), source_mapper }, features=features), source_mapper
@@ -179,11 +179,13 @@ class LegalAITrainer:
max_length=512, max_length=512,
return_tensors="pt" return_tensors="pt"
) )
# Konwersja tensorów do list i odpowiednich typów
return { return {
"input_ids": tokenized["input_ids"].squeeze().to(torch.int32), "input_ids": [ids.tolist() for ids in tokenized["input_ids"]],
"attention_mask": tokenized["attention_mask"].squeeze().to(torch.int32), "attention_mask": [mask.tolist() for mask in tokenized["attention_mask"]],
"labels": tokenized["input_ids"].squeeze().clone().to(torch.int32), "labels": [labels.tolist() for labels in tokenized["input_ids"]],
"source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32) "source_idx": [[idx] for idx in examples["source_idx"]] # Sekwencja długości 1
} }
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16) tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)