diff --git a/hft.py b/hft.py index 6e35ca2..6215459 100644 --- a/hft.py +++ b/hft.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import re import json +import numpy as np import PyPDF2 import docx2txt import pytesseract @@ -18,7 +19,6 @@ from transformers import ( from datasets import Dataset, Features, Value from huggingface_hub import login -# Konfiguracja os.environ["TOKENIZERS_PARALLELISM"] = "false" login(token="hf_REDACTED_LEAKED_TOKEN_REVOKE_AND_USE_ENV_VAR") @@ -160,8 +160,8 @@ class LegalAITrainer: return Dataset.from_dict({ "text": [d["text"] for d in data], - "source_idx": [d["source_idx"] for d in data], - "is_legal": [d["is_legal"] for d in data] + "source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32), + "is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32) }, features=features), source_mapper def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"): @@ -178,14 +178,24 @@ class LegalAITrainer: return_tensors="pt" ) return { - "input_ids": tokenized["input_ids"].squeeze(), - "attention_mask": tokenized["attention_mask"].squeeze(), - "labels": tokenized["input_ids"].squeeze().clone(), - "source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32) + "input_ids": tokenized["input_ids"].squeeze().tolist(), + "attention_mask": tokenized["attention_mask"].squeeze().tolist(), + "labels": tokenized["input_ids"].squeeze().clone().tolist(), + "source_idx": examples["source_idx"] } tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16) + class CustomDataCollator(DataCollatorForLanguageModeling): + def torch_call(self, examples): + batch = super().torch_call(examples) + if "source_idx" in examples[0]: + batch["source_idx"] = torch.tensor( + [ex["source_idx"] for ex in examples], + dtype=torch.int32 + ) + return batch + config = AutoModelForCausalLM.from_pretrained(model_name).config model = self.LegalModel(model_name, config).to(self.device) @@ -218,7 +228,7 @@ class 
LegalAITrainer: model=model, args=training_args, train_dataset=tokenized_dataset, - data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False) ) print("Rozpoczęcie treningu...") @@ -282,5 +292,5 @@ if __name__ == "__main__": catalog_path="./catalog.json" ) - test_prompt = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?" + test_prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?" print(legal_ai.generate_response(test_prompt)) \ No newline at end of file