mod
This commit is contained in:
parent
e393cf5fd8
commit
4342eb69c4
65
hft.py
65
hft.py
|
|
@ -12,9 +12,10 @@ from transformers import (
|
|||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForLanguageModeling
|
||||
)
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, Features, Value, Sequence
|
||||
from huggingface_hub import login
|
||||
|
||||
# Konfiguracja
|
||||
|
|
@ -23,8 +24,6 @@ login(token="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")
|
|||
|
||||
class LegalAITrainer:
|
||||
def __init__(self):
|
||||
self.source_mapper = defaultdict(lambda: len(self.source_mapper))
|
||||
self.idx_to_source = {}
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
class SourceMapper:
|
||||
|
|
@ -155,15 +154,23 @@ class LegalAITrainer:
|
|||
"is_legal": 0
|
||||
})
|
||||
|
||||
return Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}), source_mapper
|
||||
features = Features({
|
||||
"text": Value("string"),
|
||||
"source_idx": Value("int32"),
|
||||
"is_legal": Value("int32")
|
||||
})
|
||||
|
||||
return Dataset.from_dict({
|
||||
"text": [d["text"] for d in data],
|
||||
"source_idx": [d["source_idx"] for d in data],
|
||||
"is_legal": [d["is_legal"] for d in data]
|
||||
}, features=features), source_mapper
|
||||
|
||||
def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
|
||||
# Przygotowanie danych
|
||||
dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Tokenizacja
|
||||
def tokenize_fn(examples):
|
||||
tokenized = tokenizer(
|
||||
examples["text"],
|
||||
|
|
@ -173,22 +180,20 @@ class LegalAITrainer:
|
|||
return_tensors="pt"
|
||||
)
|
||||
return {
|
||||
"input_ids": tokenized["input_ids"].squeeze(),
|
||||
"attention_mask": tokenized["attention_mask"].squeeze(),
|
||||
"labels": tokenized["input_ids"].squeeze().clone(),
|
||||
"source_idx": examples["source_idx"]
|
||||
"input_ids": tokenized["input_ids"].squeeze().to(torch.int32),
|
||||
"attention_mask": tokenized["attention_mask"].squeeze().to(torch.int32),
|
||||
"labels": tokenized["input_ids"].squeeze().clone().to(torch.int32),
|
||||
"source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
|
||||
}
|
||||
|
||||
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
|
||||
|
||||
# Inicjalizacja modelu
|
||||
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
||||
model = self.LegalModel(model_name, config).to(self.device)
|
||||
|
||||
# Konfiguracja treningu
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./legal_ai_model",
|
||||
num_train_epochs=5,
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-5,
|
||||
|
|
@ -200,20 +205,17 @@ class LegalAITrainer:
|
|||
remove_unused_columns=False
|
||||
)
|
||||
|
||||
# Customowy Trainer
|
||||
class LegalTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
outputs = model(**inputs)
|
||||
loss = outputs.loss
|
||||
loss = outputs["loss"]
|
||||
|
||||
# Confidence loss
|
||||
target_conf = (inputs["source_idx"] != -1).float()
|
||||
conf_loss = nn.BCELoss()(outputs.confidence.squeeze(), target_conf)
|
||||
conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(), target_conf)
|
||||
|
||||
total_loss = loss + 0.7*conf_loss
|
||||
total_loss = loss + 0.7 * conf_loss
|
||||
return (total_loss, outputs) if return_outputs else total_loss
|
||||
|
||||
# Trening
|
||||
trainer = LegalTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
|
@ -224,26 +226,24 @@ class LegalAITrainer:
|
|||
print("Rozpoczęcie treningu...")
|
||||
trainer.train()
|
||||
|
||||
# Zapisz model
|
||||
model.save_pretrained("./trained_legal_ai")
|
||||
tokenizer.save_pretrained("./trained_legal_ai")
|
||||
with open("./trained_legal_ai/source_mapper.json", "w") as f:
|
||||
json.dump(source_mapper.idx_to_source, f)
|
||||
|
||||
print("Trening zakończony i model zapisany!")
|
||||
print("Trening zakończony!")
|
||||
|
||||
def generate_response(self, prompt, confidence_threshold=0.65):
|
||||
# Ładowanie modelu
|
||||
model = self.LegalModel.from_pretrained("./trained_legal_ai",
|
||||
config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
|
||||
model.to(self.device)
|
||||
model = self.LegalModel.from_pretrained(
|
||||
"./trained_legal_ai",
|
||||
config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config
|
||||
).to(self.device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai")
|
||||
|
||||
# Ładowanie mapowania źródeł
|
||||
with open("./trained_legal_ai/source_mapper.json", "r") as f:
|
||||
source_mapper = json.load(f)
|
||||
|
||||
# Przygotowanie wejścia
|
||||
inputs = tokenizer(
|
||||
f"[PROMPT] {prompt} [RESPONSE]",
|
||||
return_tensors="pt",
|
||||
|
|
@ -251,7 +251,6 @@ class LegalAITrainer:
|
|||
truncation=True
|
||||
).to(self.device)
|
||||
|
||||
# Generacja
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
input_ids=inputs.input_ids,
|
||||
|
|
@ -265,11 +264,9 @@ class LegalAITrainer:
|
|||
return_dict_in_generate=True
|
||||
)
|
||||
|
||||
# Analiza wyników
|
||||
full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
|
||||
confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item()
|
||||
|
||||
# Ekstrakcja i weryfikacja źródeł
|
||||
citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text)))
|
||||
verified = [c for c in citations if any(c in s for s in source_mapper.values())]
|
||||
|
||||
|
|
@ -281,13 +278,13 @@ class LegalAITrainer:
|
|||
if __name__ == "__main__":
|
||||
legal_ai = LegalAITrainer()
|
||||
|
||||
# Etap 1: Trening
|
||||
# Trening
|
||||
legal_ai.train(
|
||||
model_name="crumb/nano-mistral",
|
||||
data_dir="./legal_docs",
|
||||
catalog_path="./catalog.json"
|
||||
)
|
||||
|
||||
# Etap 2: Testowanie
|
||||
test_prompt = "Ile dni urlopu przysługuje po 5 latach pracy w pełnym wymiarze?"
|
||||
# Test
|
||||
test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?"
|
||||
print(legal_ai.generate_response(test_prompt))
|
||||
Loading…
Reference in New Issue