From b8d28ec05568f1ca2c9f013bcafc4ca27b3ceefe Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 19:43:37 +0100
Subject: [PATCH] testowanie

---
 hft.py | 37 ++++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/hft.py b/hft.py
index c226e03..72ad4c0 100644
--- a/hft.py
+++ b/hft.py
@@ -190,22 +190,20 @@ trainer = CustomTrainer(
 trainer.train()
 
 # Funkcja generująca odpowiedź
-def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
-    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
-
-    outputs = model.base_model.generate(
-        **inputs,
-        max_length=max_length,
-        num_return_sequences=1,
-        return_dict_in_generate=True,
-        output_scores=True,
-    )
+def generate_answer(question, max_length=200):
+    model.eval()
+    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512).to(device)
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_length=max_length,
+            num_return_sequences=1,
+            return_dict_in_generate=True
+        )
+
     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
-
-    # Pobierz źródło z ostatniego tokena
-    last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embeddi
+    return answer
 
 # Utwórz katalog do zapisu modelu
 save_directory = "./trained_model/ably.do/hse"
 
@@ -227,4 +225,13 @@ with open(os.path.join(save_directory, "source_mapper.json"), 'w') as f:
     json.dump(source_mapper_data, f)
 
 # 4. Zapisz konfigurację modelu (opcjonalnie, ale zalecane)
-model.base_model.config.save_pretrained(save_directory)
\ No newline at end of file
+model.base_model.config.save_pretrained(save_directory)
+
+# Przeprowadź testy
+test_questions = [
+    "Ile dni urlopu przysługuje pracownikowi, który przepracował w pełnym wymiarze pracy 5 lat?"
+]
+
+for q in test_questions:
+    print(f"Pytanie: {q}")
+    print(f"Odpowiedź: {generate_answer(q)}\n{'='*50}")
\ No newline at end of file