l.gabrysiak 2025-02-25 23:31:53 +01:00
parent 85ba5346fb
commit 537e191d5f
1 changed file with 19 additions and 9 deletions

hft.py

@@ -3,6 +3,7 @@ import torch
import torch.nn as nn
import re
import json
+import numpy as np
import PyPDF2
import docx2txt
import pytesseract
@@ -18,7 +19,6 @@ from transformers import (
from datasets import Dataset, Features, Value
from huggingface_hub import login
# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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")
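A side note on the login call above: the token is hard-coded in the script, and `huggingface_hub.login()` works just as well with a token read from the environment. A minimal sketch, assuming the token is exported as `HF_TOKEN` (the variable name is a common convention, not something the diff defines):

```python
# Same login as above, but the token comes from the environment rather than
# being committed to the repository. "HF_TOKEN" is an assumed variable name.
import os

from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])
```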
@@ -160,8 +160,8 @@ class LegalAITrainer:
return Dataset.from_dict({
"text": [d["text"] for d in data],
"source_idx": [d["source_idx"] for d in data],
"is_legal": [d["is_legal"] for d in data]
"source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32),
"is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32)
}, features=features), source_mapper
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
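The hunk above replaces plain Python lists with NumPy arrays carrying an explicit int32 dtype for the `source_idx` and `is_legal` columns, so they match the integer features declared for the `datasets.Dataset`. A self-contained sketch of the same pattern; the `Features` schema and the sample rows are illustrative, since the diff does not show them:

```python
# Minimal sketch: building a typed datasets.Dataset the way the hunk above does.
# Column names mirror the diff; the Features schema and the data are assumptions.
import numpy as np
from datasets import Dataset, Features, Value

features = Features({
    "text": Value("string"),
    "source_idx": Value("int32"),
    "is_legal": Value("int32"),
})

data = [
    {"text": "Art. 1. Example provision.", "source_idx": 0, "is_legal": 1},
    {"text": "Unrelated sentence.", "source_idx": 1, "is_legal": 0},
]

dataset = Dataset.from_dict({
    "text": [d["text"] for d in data],
    # Explicit int32 arrays keep the columns aligned with the declared features.
    "source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32),
    "is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32),
}, features=features)

print(dataset.features)
```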
@@ -178,14 +178,24 @@
return_tensors="pt"
)
return {
"input_ids": tokenized["input_ids"].squeeze(),
"attention_mask": tokenized["attention_mask"].squeeze(),
"labels": tokenized["input_ids"].squeeze().clone(),
"source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
"input_ids": tokenized["input_ids"].squeeze().tolist(),
"attention_mask": tokenized["attention_mask"].squeeze().tolist(),
"labels": tokenized["input_ids"].squeeze().clone().tolist(),
"source_idx": examples["source_idx"]
}
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
+class CustomDataCollator(DataCollatorForLanguageModeling):
+    def torch_call(self, examples):
+        batch = super().torch_call(examples)
+        if "source_idx" in examples[0]:
+            batch["source_idx"] = torch.tensor(
+                [ex["source_idx"] for ex in examples],
+                dtype=torch.int32
+            )
+        return batch
config = AutoModelForCausalLM.from_pretrained(model_name).config
model = self.LegalModel(model_name, config).to(self.device)
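Taken together, this hunk makes `tokenize_fn` return plain lists (so `map(batched=True)` can store them as columns) and adds a collator that carries the per-example `source_idx` into each batch. The sketch below reproduces that flow end to end so it can be run on its own; the gpt2 tokenizer, the sample texts, `max_length`, `remove_columns`, and the pop-before-delegating ordering inside the collator are choices made for a compact runnable example, not the commit's exact code:

```python
# Self-contained sketch of the tokenize -> map(batched=True) -> custom collator
# flow from the hunk above. Identifiers mirror the diff where possible; the
# tokenizer, texts, and length settings are illustrative assumptions.
import torch
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

dataset = Dataset.from_dict({
    "text": ["Art. 1. Example provision.", "Unrelated sentence."],
    "source_idx": [0, 1],
})

def tokenize_fn(examples):
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=32,
    )
    # With batched=True, returning plain Python lists per column keeps the
    # Arrow conversion straightforward; the diff likewise switches to .tolist().
    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "labels": tokenized["input_ids"],
        "source_idx": examples["source_idx"],
    }

# Drop the raw text column so the collator only receives numeric fields.
tokenized_dataset = dataset.map(
    tokenize_fn, batched=True, batch_size=16, remove_columns=["text"]
)

class CustomDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        # Pull the extra field out first, let the parent build the usual
        # language-modeling tensors, then re-attach it as a tensor.
        source_idx = [ex.pop("source_idx") for ex in examples]
        batch = super().torch_call(examples)
        batch["source_idx"] = torch.tensor(source_idx, dtype=torch.int32)
        return batch

collator = CustomDataCollator(tokenizer=tokenizer, mlm=False)
batch = collator([tokenized_dataset[i] for i in range(len(tokenized_dataset))])
print(batch["input_ids"].shape, batch["source_idx"])
```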
@@ -218,7 +228,7 @@ class LegalAITrainer:
model=model,
args=training_args,
train_dataset=tokenized_dataset,
-data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
)
print("Rozpoczęcie treningu...")
@@ -282,5 +292,5 @@ if __name__ == "__main__":
catalog_path="./catalog.json"
)
test_prompt = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?"
test_prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"
print(legal_ai.generate_response(test_prompt))
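The final hunk swaps the smoke-test prompt from a drunk-driving question to one about penalties for violating RODO (the GDPR). For readers without the surrounding `LegalAITrainer` class, a generic way to run such a prompt through a causal LM is the transformers text-generation pipeline; this is a stand-in, not the repository's `generate_response()` method, and it assumes the `crumb/nano-mistral` checkpoint named in the diff loads cleanly through `pipeline()`:

```python
# Generic prompting sketch; stands in for generate_response(), which the diff
# does not show. The model choice follows the base checkpoint named in the diff.
from transformers import pipeline

generator = pipeline("text-generation", model="crumb/nano-mistral")
prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"
output = generator(prompt, max_new_tokens=64, do_sample=False)
print(output[0]["generated_text"])
```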