diff --git a/hft.py b/hft.py index 2ae6488..27240fd 100644 --- a/hft.py +++ b/hft.py @@ -192,7 +192,7 @@ trainer.train() def generate_answer(question, model, tokenizer, source_mapper, max_length=200): inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512) - outputs = model.generate( + outputs = model.base_model.generate( **inputs, max_length=max_length, num_return_sequences=1, @@ -205,3 +205,23 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200): # Pobierz źródło z ostatniego tokena last_token_id = outputs.sequences[0][-1].item() source_idx = model.source_embeddi + + + + + +# Po zakończeniu treningu modelu + +# Przygotowanie niezbędnych komponentów +model.eval() # Przełącz model w tryb ewaluacji +model = model.to("cuda" if torch.cuda.is_available() else "cpu") # Przenieś model na GPU, jeśli jest dostępne + +# Przykładowe pytanie +question = "Ile dni urlopu przysługuje pracownikowi?" + +# Generowanie odpowiedzi +answer = generate_answer(question, model, tokenizer, source_mapper) + +# Wyświetlenie wyniku +print("Pytanie:", question) +print("Odpowiedź:", answer) \ No newline at end of file