l.gabrysiak 2025-02-25 23:21:49 +01:00
parent e393cf5fd8
commit 4342eb69c4
1 changed file with 31 additions and 34 deletions

hft.py

@@ -12,9 +12,10 @@ from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import Dataset
from datasets import Dataset, Features, Value, Sequence
from huggingface_hub import login
# Configuration
@@ -23,8 +24,6 @@ login(token="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")
class LegalAITrainer:
def __init__(self):
self.source_mapper = defaultdict(lambda: len(self.source_mapper))
self.idx_to_source = {}
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SourceMapper:
@@ -155,15 +154,23 @@ class LegalAITrainer:
"is_legal": 0
})
return Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}), source_mapper
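# Define an explicit Features schema so column dtypes are fixed instead of inferred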
features = Features({
"text": Value("string"),
"source_idx": Value("int32"),
"is_legal": Value("int32")
})
return Dataset.from_dict({
"text": [d["text"] for d in data],
"source_idx": [d["source_idx"] for d in data],
"is_legal": [d["is_legal"] for d in data]
}, features=features), source_mapper
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
# Prepare the data
dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
tokenizer = AutoTokenizer.from_pretrained(model_name)
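# Reuse the EOS token for padding, since the base tokenizer may not define a dedicated pad token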
tokenizer.pad_token = tokenizer.eos_token
# Tokenization
def tokenize_fn(examples):
tokenized = tokenizer(
examples["text"],
@@ -173,22 +180,20 @@ class LegalAITrainer:
return_tensors="pt"
)
return {
"input_ids": tokenized["input_ids"].squeeze(),
"attention_mask": tokenized["attention_mask"].squeeze(),
"labels": tokenized["input_ids"].squeeze().clone(),
"source_idx": examples["source_idx"]
"input_ids": tokenized["input_ids"].squeeze().to(torch.int32),
"attention_mask": tokenized["attention_mask"].squeeze().to(torch.int32),
"labels": tokenized["input_ids"].squeeze().clone().to(torch.int32),
"source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
}
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
# Initialize the model
config = AutoModelForCausalLM.from_pretrained(model_name).config
model = self.LegalModel(model_name, config).to(self.device)
# Training configuration
training_args = TrainingArguments(
output_dir="./legal_ai_model",
num_train_epochs=5,
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=2e-5,
@@ -200,20 +205,17 @@ class LegalAITrainer:
remove_unused_columns=False
)
# Custom Trainer
class LegalTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
outputs = model(**inputs)
loss = outputs.loss
loss = outputs["loss"]
# Confidence loss
target_conf = (inputs["source_idx"] != -1).float()
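# Train the confidence head to match this binary target (1 = known source, 0 = none) with BCE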
conf_loss = nn.BCELoss()(outputs.confidence.squeeze(), target_conf)
conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(), target_conf)
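# Combined objective: language-modelling loss plus the confidence loss weighted by 0.7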
total_loss = loss + 0.7*conf_loss
total_loss = loss + 0.7 * conf_loss
return (total_loss, outputs) if return_outputs else total_loss
# Training
trainer = LegalTrainer(
model=model,
args=training_args,
@ -224,26 +226,24 @@ class LegalAITrainer:
print("Starting training...")
trainer.train()
# Save the model
model.save_pretrained("./trained_legal_ai")
tokenizer.save_pretrained("./trained_legal_ai")
with open("./trained_legal_ai/source_mapper.json", "w") as f:
json.dump(source_mapper.idx_to_source, f)
print("Training finished and model saved!")
print("Training finished!")
def generate_response(self, prompt, confidence_threshold=0.65):
# Load the model
model = self.LegalModel.from_pretrained("./trained_legal_ai",
config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config)
tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
model.to(self.device)
model = self.LegalModel.from_pretrained(
"./trained_legal_ai",
config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config
).to(self.device)
tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
# Load the source mapping
with open("./trained_legal_ai/source_mapper.json", "r") as f:
source_mapper = json.load(f)
# Prepare the input
inputs = tokenizer(
f"[PROMPT] {prompt} [RESPONSE]",
return_tensors="pt",
@@ -251,7 +251,6 @@ class LegalAITrainer:
truncation=True
).to(self.device)
# Generation
with torch.no_grad():
outputs = model.generate(
input_ids=inputs.input_ids,
@@ -265,11 +264,9 @@ class LegalAITrainer:
return_dict_in_generate=True
)
# Analyze the results
full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
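# Confidence heuristic: sigmoid of the EOS-token logit from the last generation step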
confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()
# Extract and verify source citations
citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
verified = [c for c in citations if any(c in s for s in source_mapper.values())]
@@ -281,13 +278,13 @@ class LegalAITrainer:
if __name__ == "__main__":
legal_ai = LegalAITrainer()
# Stage 1: Training
# Training
legal_ai.train(
model_name="crumb/nano-mistral",
data_dir="./legal_docs",
catalog_path="./catalog.json"
)
# Stage 2: Testing
test_prompt = "Ile dni urlopu przysługuje po 5 latach pracy w pełnym wymiarze?"
# Test
test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?"
print(legal_ai.generate_response(test_prompt))