This commit is contained in:
l.gabrysiak 2025-02-26 10:18:49 +01:00
parent 6816562163
commit 2ea7c63c7f
1 changed files with 10 additions and 3 deletions

13
test.py
View File

@ -3,13 +3,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "./trained_model/gpt" model_path = "./trained_model/gpt"
model = AutoModelForCausalLM.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
def generate_response(prompt, max_length=100): def generate_response(prompt, max_length=100):
inputs = tokenizer(prompt, return_tensors="pt") inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
outputs = model.generate(inputs.input_ids, max_length=max_length, num_return_sequences=1, do_sample=True) outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
pad_token_id=tokenizer.pad_token_id,
max_length=100
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True) response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return response return response
prompt = "Jakie są prawa pracownika zgodnie z Kodeksem pracy?" prompt = "Ile dni urlopu przysługuje pracownikowi z 5 letnim stazem pracy w pełnym wymiarze pracy?"
response = generate_response(prompt) response = generate_response(prompt)
print(response) print(response)