This commit is contained in:
l.gabrysiak 2025-02-25 23:21:49 +01:00
parent e393cf5fd8
commit 4342eb69c4
1 changed file with 31 additions and 34 deletions

65
hft.py
View File

@ -12,9 +12,10 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
AutoModelForCausalLM, AutoModelForCausalLM,
TrainingArguments, TrainingArguments,
Trainer,
DataCollatorForLanguageModeling DataCollatorForLanguageModeling
) )
from datasets import Dataset from datasets import Dataset, Features, Value, Sequence
from huggingface_hub import login from huggingface_hub import login
# Konfiguracja # Konfiguracja
@ -23,8 +24,6 @@ login(token="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")
class LegalAITrainer: class LegalAITrainer:
def __init__(self): def __init__(self):
self.source_mapper = defaultdict(lambda: len(self.source_mapper))
self.idx_to_source = {}
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SourceMapper: class SourceMapper:
@ -155,15 +154,23 @@ class LegalAITrainer:
"is_legal": 0 "is_legal": 0
}) })
return Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}), source_mapper features = Features({
"text": Value("string"),
"source_idx": Value("int32"),
"is_legal": Value("int32")
})
return Dataset.from_dict({
"text": [d["text"] for d in data],
"source_idx": [d["source_idx"] for d in data],
"is_legal": [d["is_legal"] for d in data]
}, features=features), source_mapper
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"): def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
# Przygotowanie danych
dataset, source_mapper = self.prepare_data(data_dir, catalog_path) dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
# Tokenizacja
def tokenize_fn(examples): def tokenize_fn(examples):
tokenized = tokenizer( tokenized = tokenizer(
examples["text"], examples["text"],
@ -173,22 +180,20 @@ class LegalAITrainer:
return_tensors="pt" return_tensors="pt"
) )
return { return {
"input_ids": tokenized["input_ids"].squeeze(), "input_ids": tokenized["input_ids"].squeeze().to(torch.int32),
"attention_mask": tokenized["attention_mask"].squeeze(), "attention_mask": tokenized["attention_mask"].squeeze().to(torch.int32),
"labels": tokenized["input_ids"].squeeze().clone(), "labels": tokenized["input_ids"].squeeze().clone().to(torch.int32),
"source_idx": examples["source_idx"] "source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
} }
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16) tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
# Inicjalizacja modelu
config = AutoModelForCausalLM.from_pretrained(model_name).config config = AutoModelForCausalLM.from_pretrained(model_name).config
model = self.LegalModel(model_name, config).to(self.device) model = self.LegalModel(model_name, config).to(self.device)
# Konfiguracja treningu
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir="./legal_ai_model", output_dir="./legal_ai_model",
num_train_epochs=5, num_train_epochs=3,
per_device_train_batch_size=2, per_device_train_batch_size=2,
gradient_accumulation_steps=4, gradient_accumulation_steps=4,
learning_rate=2e-5, learning_rate=2e-5,
@ -200,20 +205,17 @@ class LegalAITrainer:
remove_unused_columns=False remove_unused_columns=False
) )
# Customowy Trainer
class LegalTrainer(Trainer): class LegalTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
outputs = model(**inputs) outputs = model(**inputs)
loss = outputs.loss loss = outputs["loss"]
# Confidence loss
target_conf = (inputs["source_idx"] != -1).float() target_conf = (inputs["source_idx"] != -1).float()
conf_loss = nn.BCELoss()(outputs.confidence.squeeze(), target_conf) conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(), target_conf)
total_loss = loss + 0.7*conf_loss total_loss = loss + 0.7 * conf_loss
return (total_loss, outputs) if return_outputs else total_loss return (total_loss, outputs) if return_outputs else total_loss
# Trening
trainer = LegalTrainer( trainer = LegalTrainer(
model=model, model=model,
args=training_args, args=training_args,
@ -224,26 +226,24 @@ class LegalAITrainer:
print("Rozpoczęcie treningu...") print("Rozpoczęcie treningu...")
trainer.train() trainer.train()
# Zapisz model
model.save_pretrained("./trained_legal_ai") model.save_pretrained("./trained_legal_ai")
tokenizer.save_pretrained("./trained_legal_ai") tokenizer.save_pretrained("./trained_legal_ai")
with open("./trained_legal_ai/source_mapper.json", "w") as f: with open("./trained_legal_ai/source_mapper.json", "w") as f:
json.dump(source_mapper.idx_to_source, f) json.dump(source_mapper.idx_to_source, f)
print("Trening zakończony i model zapisany!") print("Trening zakończony!")
def generate_response(self, prompt, confidence_threshold=0.65): def generate_response(self, prompt, confidence_threshold=0.65):
# Ładowanie modelu model = self.LegalModel.from_pretrained(
model = self.LegalModel.from_pretrained("./trained_legal_ai", "./trained_legal_ai",
config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config) config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config
tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai") ).to(self.device)
model.to(self.device)
tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
# Ładowanie mapowania źródeł
with open("./trained_legal_ai/source_mapper.json", "r") as f: with open("./trained_legal_ai/source_mapper.json", "r") as f:
source_mapper = json.load(f) source_mapper = json.load(f)
# Przygotowanie wejścia
inputs = tokenizer( inputs = tokenizer(
f"[PROMPT] {prompt} [RESPONSE]", f"[PROMPT] {prompt} [RESPONSE]",
return_tensors="pt", return_tensors="pt",
@ -251,7 +251,6 @@ class LegalAITrainer:
truncation=True truncation=True
).to(self.device) ).to(self.device)
# Generacja
with torch.no_grad(): with torch.no_grad():
outputs = model.generate( outputs = model.generate(
input_ids=inputs.input_ids, input_ids=inputs.input_ids,
@ -265,11 +264,9 @@ class LegalAITrainer:
return_dict_in_generate=True return_dict_in_generate=True
) )
# Analiza wyników
full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item() confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()
# Ekstrakcja i weryfikacja źródeł
citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text))) citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
verified = [c for c in citations if any(c in s for s in source_mapper.values())] verified = [c for c in citations if any(c in s for s in source_mapper.values())]
@ -281,13 +278,13 @@ class LegalAITrainer:
if __name__ == "__main__": if __name__ == "__main__":
legal_ai = LegalAITrainer() legal_ai = LegalAITrainer()
# Etap 1: Trening # Trening
legal_ai.train( legal_ai.train(
model_name="crumb/nano-mistral", model_name="crumb/nano-mistral",
data_dir="./legal_docs", data_dir="./legal_docs",
catalog_path="./catalog.json" catalog_path="./catalog.json"
) )
# Etap 2: Testowanie # Test
test_prompt = "Ile dni urlopu przysługuje po 5 latach pracy w pełnym wymiarze?" test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?"
print(legal_ai.generate_response(test_prompt)) print(legal_ai.generate_response(test_prompt))