diff --git a/hft.py b/hft.py
index f774ead..5a2e8f7 100644
--- a/hft.py
+++ b/hft.py
@@ -12,9 +12,10 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     TrainingArguments,
+    Trainer,
     DataCollatorForLanguageModeling
 )
-from datasets import Dataset
+from datasets import Dataset, Features, Value, Sequence
 from huggingface_hub import login
 
 # Configuration
@@ -23,8 +24,6 @@ login(token="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")
 
 class LegalAITrainer:
     def __init__(self):
-        self.source_mapper = defaultdict(lambda: len(self.source_mapper))
-        self.idx_to_source = {}
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     class SourceMapper:
@@ -155,15 +154,23 @@ class LegalAITrainer:
                     "is_legal": 0
                 })
 
-        return Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}), source_mapper
+        features = Features({
+            "text": Value("string"),
+            "source_idx": Value("int32"),
+            "is_legal": Value("int32")
+        })
+
+        return Dataset.from_dict({
+            "text": [d["text"] for d in data],
+            "source_idx": [d["source_idx"] for d in data],
+            "is_legal": [d["is_legal"] for d in data]
+        }, features=features), source_mapper
 
     def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json"):
-        # Prepare data
         dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         tokenizer.pad_token = tokenizer.eos_token
 
-        # Tokenization
         def tokenize_fn(examples):
             tokenized = tokenizer(
                 examples["text"],
@@ -173,22 +180,20 @@ class LegalAITrainer:
                 return_tensors="pt"
             )
             return {
-                "input_ids": tokenized["input_ids"].squeeze(),
-                "attention_mask": tokenized["attention_mask"].squeeze(),
-                "labels": tokenized["input_ids"].squeeze().clone(),
-                "source_idx": examples["source_idx"]
+                "input_ids": tokenized["input_ids"].squeeze().long(),
+                "attention_mask": tokenized["attention_mask"].squeeze().long(),
+                "labels": tokenized["input_ids"].squeeze().clone().long(),  # cross-entropy targets must be torch.long
+                "source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
             }
 
         tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
 
-        # Initialize model
         config = AutoModelForCausalLM.from_pretrained(model_name).config
         model = self.LegalModel(model_name, config).to(self.device)
 
-        # Training configuration
         training_args = TrainingArguments(
             output_dir="./legal_ai_model",
-            num_train_epochs=5,
+            num_train_epochs=3,
             per_device_train_batch_size=2,
             gradient_accumulation_steps=4,
             learning_rate=2e-5,
@@ -200,20 +205,17 @@ class LegalAITrainer:
             remove_unused_columns=False
         )
 
-        # Custom Trainer
         class LegalTrainer(Trainer):
             def compute_loss(self, model, inputs, return_outputs=False):
                 outputs = model(**inputs)
-                loss = outputs.loss
+                loss = outputs["loss"]
 
-                # Confidence loss
                 target_conf = (inputs["source_idx"] != -1).float()
-                conf_loss = nn.BCELoss()(outputs.confidence.squeeze(), target_conf)
+                conf_loss = nn.BCELoss()(outputs["confidence"].squeeze(), target_conf)
 
-                total_loss = loss + 0.7*conf_loss
+                total_loss = loss + 0.7 * conf_loss
                 return (total_loss, outputs) if return_outputs else total_loss
 
-        # Training
         trainer = LegalTrainer(
             model=model,
             args=training_args,
@@ -224,26 +226,24 @@ class LegalAITrainer:
 
         print("Starting training...")
         trainer.train()
 
-        # Save model
         model.save_pretrained("./trained_legal_ai")
         tokenizer.save_pretrained("./trained_legal_ai")
 
         with open("./trained_legal_ai/source_mapper.json", "w") as f:
             json.dump(source_mapper.idx_to_source, f)
 
-        print("Training finished and model saved!")
+        print("Training finished!")
zakończony!") def generate_response(self, prompt, confidence_threshold=0.65): - # Ładowanie modelu - model = self.LegalModel.from_pretrained("./trained_legal_ai", - config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config) - tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai") - model.to(self.device) + model = self.LegalModel.from_pretrained( + "./trained_legal_ai", + config=AutoModelForCausalLM.from_pretrained("crumb/nano-mistral").config + ).to(self.device) + + tokenizer = AutoTokenizer.from_pretrained("./trained_legal_ai") - # Ładowanie mapowania źródeł with open("./trained_legal_ai/source_mapper.json", "r") as f: source_mapper = json.load(f) - # Przygotowanie wejścia inputs = tokenizer( f"[PROMPT] {prompt} [RESPONSE]", return_tensors="pt", @@ -251,7 +251,6 @@ class LegalAITrainer: truncation=True ).to(self.device) - # Generacja with torch.no_grad(): outputs = model.generate( input_ids=inputs.input_ids, @@ -265,11 +264,9 @@ class LegalAITrainer: return_dict_in_generate=True ) - # Analiza wyników full_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) confidence = torch.sigmoid(outputs.scores[-1][:, tokenizer.eos_token_id]).item() - # Ekstrakcja i weryfikacja źródeł citations = list(set(re.findall(r"Art\.\s*\d+[a-z]*", full_text))) verified = [c for c in citations if any(c in s for s in source_mapper.values())] @@ -281,13 +278,13 @@ class LegalAITrainer: if __name__ == "__main__": legal_ai = LegalAITrainer() - # Etap 1: Trening + # Trening legal_ai.train( model_name="crumb/nano-mistral", data_dir="./legal_docs", catalog_path="./catalog.json" ) - # Etap 2: Testowanie - test_prompt = "Ile dni urlopu przysługuje po 5 latach pracy w pełnym wymiarze?" + # Test + test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?" print(legal_ai.generate_response(test_prompt)) \ No newline at end of file