From 2ea7c63c7fbd9e4b2f0abf58c1d5740ab7474021 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Wed, 26 Feb 2025 10:18:49 +0100
Subject: [PATCH] mod test

---
 test.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/test.py b/test.py
index 9f56d58..2e11130 100644
--- a/test.py
+++ b/test.py
@@ -3,13 +3,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 model_path = "./trained_model/gpt"
 model = AutoModelForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path)
+tokenizer.pad_token = tokenizer.eos_token
+model.config.pad_token_id = tokenizer.eos_token_id
 
 def generate_response(prompt, max_length=100):
-    inputs = tokenizer(prompt, return_tensors="pt")
-    outputs = model.generate(inputs.input_ids, max_length=max_length, num_return_sequences=1, do_sample=True)
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+    outputs = model.generate(
+        inputs.input_ids,
+        attention_mask=inputs.attention_mask,
+        pad_token_id=tokenizer.pad_token_id,
+        max_length=max_length
+    )
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
-prompt = "Jakie są prawa pracownika zgodnie z Kodeksem pracy?"
+prompt = "Ile dni urlopu przysługuje pracownikowi z 5 letnim stazem pracy w pełnym wymiarze pracy?"
 response = generate_response(prompt)
 print(response)
\ No newline at end of file