mod
This commit is contained in:
parent
85ba5346fb
commit
537e191d5f
28
hft.py
28
hft.py
|
|
@ -3,6 +3,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
|
import numpy as np
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
import docx2txt
|
import docx2txt
|
||||||
import pytesseract
|
import pytesseract
|
||||||
|
|
@ -18,7 +19,6 @@ from transformers import (
|
||||||
from datasets import Dataset, Features, Value
|
from datasets import Dataset, Features, Value
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
|
|
||||||
# Konfiguracja
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
login(token="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")
|
login(token="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")
|
||||||
|
|
||||||
|
|
@ -160,8 +160,8 @@ class LegalAITrainer:
|
||||||
|
|
||||||
return Dataset.from_dict({
|
return Dataset.from_dict({
|
||||||
"text": [d["text"] for d in data],
|
"text": [d["text"] for d in data],
|
||||||
"source_idx": [d["source_idx"] for d in data],
|
"source_idx": np.array([d["source_idx"] for d in data], dtype=np.int32),
|
||||||
"is_legal": [d["is_legal"] for d in data]
|
"is_legal": np.array([d["is_legal"] for d in data], dtype=np.int32)
|
||||||
}, features=features), source_mapper
|
}, features=features), source_mapper
|
||||||
|
|
||||||
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
|
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
|
||||||
|
|
@ -178,14 +178,24 @@ class LegalAITrainer:
|
||||||
return_tensors="pt"
|
return_tensors="pt"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"input_ids": tokenized["input_ids"].squeeze(),
|
"input_ids": tokenized["input_ids"].squeeze().tolist(),
|
||||||
"attention_mask": tokenized["attention_mask"].squeeze(),
|
"attention_mask": tokenized["attention_mask"].squeeze().tolist(),
|
||||||
"labels": tokenized["input_ids"].squeeze().clone(),
|
"labels": tokenized["input_ids"].squeeze().clone().tolist(),
|
||||||
"source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
|
"source_idx": examples["source_idx"]
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
|
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
|
||||||
|
|
||||||
|
class CustomDataCollator(DataCollatorForLanguageModeling):
|
||||||
|
def torch_call(self, examples):
|
||||||
|
batch = super().torch_call(examples)
|
||||||
|
if "source_idx" in examples[0]:
|
||||||
|
batch["source_idx"] = torch.tensor(
|
||||||
|
[ex["source_idx"] for ex in examples],
|
||||||
|
dtype=torch.int32
|
||||||
|
)
|
||||||
|
return batch
|
||||||
|
|
||||||
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
||||||
model = self.LegalModel(model_name, config).to(self.device)
|
model = self.LegalModel(model_name, config).to(self.device)
|
||||||
|
|
||||||
|
|
@ -218,7 +228,7 @@ class LegalAITrainer:
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=tokenized_dataset,
|
train_dataset=tokenized_dataset,
|
||||||
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Rozpoczęcie treningu...")
|
print("Rozpoczęcie treningu...")
|
||||||
|
|
@ -282,5 +292,5 @@ if __name__ == "__main__":
|
||||||
catalog_path="./catalog.json"
|
catalog_path="./catalog.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
test_prompt = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?"
|
test_prompt = "Jakie są kary za nieprzestrzeganie przepisów RODO?"
|
||||||
print(legal_ai.generate_response(test_prompt))
|
print(legal_ai.generate_response(test_prompt))
|
||||||
Loading…
Reference in New Issue