This commit is contained in:
l.gabrysiak 2025-02-25 23:31:53 +01:00
parent 85ba5346fb
commit 537e191d5f
1 changed file with 19 additions and 9 deletions

28
hft.py
View File

@ -3,6 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import re import re
import json import json
import numpy as np
import PyPDF2 import PyPDF2
import docx2txt import docx2txt
import pytesseract import pytesseract
@ -18,7 +19,6 @@ from transformers import (
from datasets import Dataset, Features, Value from datasets import Dataset, Features, Value
from huggingface_hub import login from huggingface_hub import login
# Konfiguracja
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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") login(token="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")
@ -160,8 +160,8 @@ class LegalAITrainer:
return Dataset.from_dict({ return Dataset.from_dict({
"text": [d["text"] for d in data], "text": [d["text"] for d in data],
"source_idx": [d["source_idx"] for d in data], "source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32),
"is_legal": [d["is_legal"] for d in data] "is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32)
}, features=features), source_mapper }, features=features), source_mapper
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"): def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
@ -178,14 +178,24 @@ class LegalAITrainer:
return_tensors="pt" return_tensors="pt"
) )
return { return {
"input_ids": tokenized["input_ids"].squeeze(), "input_ids": tokenized["input_ids"].squeeze().tolist(),
"attention_mask": tokenized["attention_mask"].squeeze(), "attention_mask": tokenized["attention_mask"].squeeze().tolist(),
"labels": tokenized["input_ids"].squeeze().clone(), "labels": tokenized["input_ids"].squeeze().clone().tolist(),
"source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32) "source_idx": examples["source_idx"]
} }
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16) tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
class CustomDataCollator(DataCollatorForLanguageModeling):
    """Language-modeling collator that preserves the ``source_idx`` field.

    Padding and label construction are delegated to the parent collator;
    afterwards the per-example ``source_idx`` values are re-attached to the
    batch as an int32 tensor so downstream code can map samples back to
    their source documents.
    """

    def torch_call(self, examples):
        # Let the parent collator build the standard LM batch first.
        collated = super().torch_call(examples)
        # NOTE(review): assumes the batch is non-empty — confirm with caller.
        if "source_idx" in examples[0]:
            source_indices = [sample["source_idx"] for sample in examples]
            collated["source_idx"] = torch.tensor(source_indices, dtype=torch.int32)
        return collated
config = AutoModelForCausalLM.from_pretrained(model_name).config config = AutoModelForCausalLM.from_pretrained(model_name).config
model = self.LegalModel(model_name, config).to(self.device) model = self.LegalModel(model_name, config).to(self.device)
@ -218,7 +228,7 @@ class LegalAITrainer:
model=model, model=model,
args=training_args, args=training_args,
train_dataset=tokenized_dataset, train_dataset=tokenized_dataset,
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
) )
print("Rozpoczęcie treningu...") print("Rozpoczęcie treningu...")
@ -282,5 +292,5 @@ if __name__ == "__main__":
catalog_path="./catalog.json" catalog_path="./catalog.json"
) )
test_prompt = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?" test_prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"
print(legal_ai.generate_response(test_prompt)) print(legal_ai.generate_response(test_prompt))