mod

parent 85ba5346fb
commit 537e191d5f

hft.py | 28
@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
 import re
 import json
+import numpy as np
 import PyPDF2
 import docx2txt
 import pytesseract
@@ -18,7 +19,6 @@ from transformers import (
 from datasets import Dataset, Features, Value
 from huggingface_hub import login
 
-# Konfiguracja
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 login(token="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")
 
@@ -160,8 +160,8 @@ class LegalAITrainer:
 
         return Dataset.from_dict({
             "text": [d["text"] for d in data],
-            "source_idx": [d["source_idx"] for d in data],
-            "is_legal": [d["is_legal"] for d in data]
+            "source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32),
+            "is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32)
         }, features=features), source_mapper
 
     def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
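The `features=features` context line above implies a fixed Arrow schema for the three columns. That definition sits outside this hunk, so the sketch below is only an assumption about what it likely looks like, with dtypes inferred from the np.int32 casts introduced here.

# Hypothetical sketch of the schema implied by features=features above;
# the real definition is not shown in this diff, so names and dtypes are assumed.
from datasets import Features, Value

features = Features({
    "text": Value("string"),
    "source_idx": Value("int32"),   # matches the np.int32 cast added in this hunk
    "is_legal": Value("int32"),
})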
@@ -178,14 +178,24 @@ class LegalAITrainer:
                 return_tensors="pt"
             )
             return {
-                "input_ids": tokenized["input_ids"].squeeze(),
-                "attention_mask": tokenized["attention_mask"].squeeze(),
-                "labels": tokenized["input_ids"].squeeze().clone(),
-                "source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
+                "input_ids": tokenized["input_ids"].squeeze().tolist(),
+                "attention_mask": tokenized["attention_mask"].squeeze().tolist(),
+                "labels": tokenized["input_ids"].squeeze().clone().tolist(),
+                "source_idx": examples["source_idx"]
             }
 
         tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
 
+        class CustomDataCollator(DataCollatorForLanguageModeling):
+            def torch_call(self, examples):
+                batch = super().torch_call(examples)
+                if "source_idx" in examples[0]:
+                    batch["source_idx"] = torch.tensor(
+                        [ex["source_idx"] for ex in examples],
+                        dtype=torch.int32
+                    )
+                return batch
+
         config = AutoModelForCausalLM.from_pretrained(model_name).config
         model = self.LegalModel(model_name, config).to(self.device)
 
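For context on the `.tolist()` change: `datasets.map` writes its output into an Arrow table, which stores plain Python lists and NumPy values rather than torch tensors; tensors are then rebuilt per batch by the collator. Below is a minimal, self-contained sketch of that pattern with a fake tokenizer and toy data (hypothetical names, not the repository's real pipeline).

# Toy illustration: map() stores lists in the Arrow-backed dataset,
# and tensors are only created later, per batch, by the data collator.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["abc", "de"], "source_idx": [0, 1]})

def tokenize_fn(examples):
    # stand-in for the real tokenizer: character codes as plain lists of ints
    ids = [[ord(c) for c in t] for t in examples["text"]]
    return {"input_ids": ids, "labels": ids, "source_idx": examples["source_idx"]}

ds = ds.map(tokenize_fn, batched=True)
print(ds[0])  # lists round-trip through Arrow; torch tensors would not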
@@ -218,7 +228,7 @@ class LegalAITrainer:
             model=model,
             args=training_args,
             train_dataset=tokenized_dataset,
-            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+            data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
         )
 
         print("Rozpoczęcie treningu...")
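One caveat worth noting: Trainer prunes dataset columns that the model's forward() does not accept before they reach the data collator, so source_idx only survives to CustomDataCollator if LegalModel.forward declares a source_idx argument or if that pruning is disabled. A sketch of the relevant flag follows; output_dir is a placeholder and the surrounding TrainingArguments are assumed.

# Assumption: if LegalModel.forward does not declare source_idx, keep the
# column by disabling Trainer's automatic pruning of unused dataset columns.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",              # placeholder path
    remove_unused_columns=False,   # let source_idx reach CustomDataCollator
)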
@@ -282,5 +292,5 @@ if __name__ == "__main__":
         catalog_path="./catalog.json"
     )
 
-    test_prompt = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?"
+    test_prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"
     print(legal_ai.generate_response(test_prompt))