commit 4342eb69c4 (parent e393cf5fd8)

hft.py | 65
@@ -12,9 +12,10 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     TrainingArguments,
+    Trainer,
     DataCollatorForLanguageModeling
 )
-from datasets import Dataset
+from datasets import Dataset, Features, Value, Sequence
 from huggingface_hub import login

 # Configuration
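A note on the widened `datasets` import: `Features` and `Value` let the dataset pin its column dtypes instead of letting Arrow infer them (`Sequence` is imported but never used in this diff). A minimal sketch of the pattern, with illustrative column names:

    from datasets import Dataset, Features, Value

    # Pin the schema up front instead of letting Arrow infer it from the values.
    features = Features({"text": Value("string"), "label": Value("int32")})
    ds = Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]}, features=features)
    print(ds.features)  # shows the declared schema, with label as int32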
@@ -23,8 +24,6 @@ login(token="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")

 class LegalAITrainer:
     def __init__(self):
-        self.source_mapper = defaultdict(lambda: len(self.source_mapper))
-        self.idx_to_source = {}
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 class SourceMapper:
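The two deleted lines are the standard auto-indexing idiom: a `defaultdict` whose factory returns the dict's own current length assigns each unseen key the next consecutive index. Presumably that state now lives in the `SourceMapper` class instead of the trainer. A standalone sketch of the idiom (file names illustrative):

    from collections import defaultdict

    # First lookup of an unseen key assigns it the next free integer index.
    source_mapper = defaultdict(lambda: len(source_mapper))
    source_mapper["kodeks_pracy.txt"]    # -> 0
    source_mapper["kodeks_cywilny.txt"]  # -> 1
    idx_to_source = {v: k for k, v in source_mapper.items()}  # reverse map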
@@ -155,15 +154,23 @@ class LegalAITrainer:
                 "is_legal": 0
             })

-        return Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}), source_mapper
+        features = Features({
+            "text": Value("string"),
+            "source_idx": Value("int32"),
+            "is_legal": Value("int32")
+        })
+
+        return Dataset.from_dict({
+            "text": [d["text"] for d in data],
+            "source_idx": [d["source_idx"] for d in data],
+            "is_legal": [d["is_legal"] for d in data]
+        }, features=features), source_mapper

     def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
-        # Prepare the data
         dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         tokenizer.pad_token = tokenizer.eos_token

-        # Tokenization
         def tokenize_fn(examples):
             tokenized = tokenizer(
                 examples["text"],
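The replaced one-liner and the new explicit dict build the same columnar mapping (a transpose of the row dicts); the functional change is only that the result is now validated against the pinned `Features` schema. For reference, the equivalence:

    # One illustrative row in the same shape this script accumulates.
    data = [{"text": "Art. 1 ...", "source_idx": 0, "is_legal": 1}]

    # The old one-liner and the new explicit version produce the same columns.
    old_columns = {k: [d[k] for d in data] for k in data[0]}
    new_columns = {
        "text": [d["text"] for d in data],
        "source_idx": [d["source_idx"] for d in data],
        "is_legal": [d["is_legal"] for d in data],
    }
    assert old_columns == new_columns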
@@ -173,22 +180,20 @@ class LegalAITrainer:
                 return_tensors="pt"
             )
             return {
-                "input_ids": tokenized["input_ids"].squeeze(),
-                "attention_mask": tokenized["attention_mask"].squeeze(),
-                "labels": tokenized["input_ids"].squeeze().clone(),
-                "source_idx": examples["source_idx"]
+                "input_ids": tokenized["input_ids"].squeeze().to(torch.int32),
+                "attention_mask": tokenized["attention_mask"].squeeze().to(torch.int32),
+                "labels": tokenized["input_ids"].squeeze().clone().to(torch.int32),
+                "source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
             }

         tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)

-        # Initialize the model
         config = AutoModelForCausalLM.from_pretrained(model_name).config
         model = self.LegalModel(model_name, config).to(self.device)

-        # Training configuration
         training_args = TrainingArguments(
             output_dir="./legal_ai_model",
-            num_train_epochs=5,
+            num_train_epochs=3,
             per_device_train_batch_size=2,
             gradient_accumulation_steps=4,
             learning_rate=2e-5,
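One caution on the new int32 casts, worth verifying against the model code: `nn.Embedding` accepts int32 ids, but PyTorch's cross-entropy requires class-index targets of dtype int64 (`torch.long`), so an int32 `labels` tensor can fail inside the underlying causal-LM loss. A quick standalone check:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)  # (batch, vocab)
    labels32 = torch.randint(0, 10, (4,), dtype=torch.int32)

    try:
        F.cross_entropy(logits, labels32)        # int32 targets are rejected
    except RuntimeError as e:
        print("int32 labels rejected:", e)

    loss = F.cross_entropy(logits, labels32.long())  # fine once cast back to int64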
@@ -200,20 +205,17 @@
             remove_unused_columns=False
         )

-        # Custom Trainer
         class LegalTrainer(Trainer):
             def compute_loss(self, model, inputs, return_outputs=False):
                 outputs = model(**inputs)
-                loss = outputs.loss
+                loss = outputs["loss"]

-                # Confidence loss
                 target_conf = (inputs["source_idx"] != -1).float()
-                conf_loss = nn.BCELoss()(outputs.confidence.squeeze(), target_conf)
+                conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(), target_conf)

-                total_loss = loss + 0.7*conf_loss
+                total_loss = loss + 0.7 * conf_loss
                 return (total_loss, outputs) if return_outputs else total_loss

-        # Training
         trainer = LegalTrainer(
             model=model,
             args=training_args,
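`nn.BCELoss` on `outputs["confidence"]` assumes the confidence head already ends in a sigmoid; if it emits raw logits, `nn.BCEWithLogitsLoss` is the numerically safer equivalent. A minimal sketch of the combined objective with the 0.7 weight from this diff (the head's output here is a stand-in):

    import torch
    import torch.nn as nn

    lm_loss = torch.tensor(2.3)   # stand-in for outputs["loss"]
    conf_logits = torch.randn(8)  # stand-in for a raw (pre-sigmoid) confidence head
    source_idx = torch.randint(-1, 5, (8,))
    target_conf = (source_idx != -1).float()

    # BCEWithLogitsLoss folds the sigmoid into the loss for numerical stability.
    conf_loss = nn.BCEWithLogitsLoss()(conf_logits, target_conf)
    total_loss = lm_loss + 0.7 * conf_loss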
@@ -224,26 +226,24 @@
         print("Starting training...")
         trainer.train()

-        # Save the model
         model.save_pretrained("./trained_legal_ai")
         tokenizer.save_pretrained("./trained_legal_ai")
         with open("./trained_legal_ai/source_mapper.json", "w") as f:
             json.dump(source_mapper.idx_to_source, f)

-        print("Training finished and model saved!")
+        print("Training finished!")

     def generate_response(self, prompt, confidence_threshold=0.65):
-        # Load the model
-        model = self.LegalModel.from_pretrained("./trained_legal_ai",
-            config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config)
-        tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
-        model.to(self.device)
+        model = self.LegalModel.from_pretrained(
+            "./trained_legal_ai",
+            config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config
+        ).to(self.device)
+        tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")

-        # Load the source mapping
         with open("./trained_legal_ai/source_mapper.json", "r") as f:
             source_mapper = json.load(f)

-        # Prepare the input
         inputs = tokenizer(
             f"[PROMPT] {prompt} [RESPONSE]",
             return_tensors="pt",
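Two small points here: `.to()` returns the module, so chaining it onto `from_pretrained(...)` is equivalent to the separate `model.to(self.device)` call it replaces; and because JSON object keys are always strings, the integer indices written by `json.dump(source_mapper.idx_to_source, f)` come back as "0", "1", ... on reload. A sketch of a reload that accounts for that (file name from this diff, values illustrative):

    import json

    # JSON stringifies dict keys, so int indices round-trip as str.
    with open("./trained_legal_ai/source_mapper.json") as f:
        source_mapper = json.load(f)  # e.g. {"0": "kodeks_pracy.txt", ...}

    source = source_mapper.get(str(3), "unknown")  # look up by stringified index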
@@ -251,7 +251,6 @@
             truncation=True
         ).to(self.device)

-        # Generation
         with torch.no_grad():
             outputs = model.generate(
                 input_ids=inputs.input_ids,
@@ -265,11 +264,9 @@
                 return_dict_in_generate=True
             )

-        # Analyze the results
         full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
         confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()

-        # Extract and verify sources
         citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
         verified = [c for c in citations if any(c in s for s in source_mapper.values())]
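Two observations on this block: `torch.sigmoid` over the final step's eos logit is a rough confidence proxy rather than a calibrated probability (`outputs.scores` holds per-step logits), and citation extraction is a plain regex in which `r"Art\.\s*\d+[a-z]*"` matches forms like `Art. 30` or `Art. 52a`. A quick check of the regex on a sample sentence:

    import re

    text = "Zgodnie z Art. 52a oraz Art. 30 Kodeksu pracy..."
    citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", text)))
    # e.g. ['Art. 52a', 'Art. 30'] (set order is arbitrary)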
@@ -281,13 +278,13 @@
 if __name__ == "__main__":
     legal_ai = LegalAITrainer()

-    # Stage 1: Training
+    # Training
     legal_ai.train(
         model_name="crumb/nano-mistral",
         data_dir="./legal_docs",
         catalog_path="./catalog.json"
     )

-    # Stage 2: Testing
-    test_prompt = "Ile dni urlopu przysługuje po 5 latach pracy w pełnym wymiarze?"  # "How many days of leave after 5 years of full-time work?"
+    # Test
+    test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?"  # "What are the employer's occupational health and safety obligations?"
     print(legal_ai.generate_response(test_prompt))