From f49fd406c529e29994ebbf81e594f70874fb6d0c Mon Sep 17 00:00:00 2001 From: "l.gabrysiak" Date: Wed, 26 Feb 2025 09:49:28 +0100 Subject: [PATCH] mod gpt + test --- gpt.py | 2 +- test.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 test.py diff --git a/gpt.py b/gpt.py index cce91f3..f35a456 100644 --- a/gpt.py +++ b/gpt.py @@ -99,7 +99,7 @@ def main(): print("Rozpoczęcie treningu...") trainer.train() - trainer.save_model("./trained_model") + trainer.save_model("./trained_model/gpt") if __name__ == "__main__": main() \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..9f56d58 --- /dev/null +++ b/test.py @@ -0,0 +1,15 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_path = "./trained_model/gpt" +model = AutoModelForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) + +def generate_response(prompt, max_length=100): + inputs = tokenizer(prompt, return_tensors="pt") + outputs = model.generate(inputs.input_ids, max_length=max_length, num_return_sequences=1, do_sample=True) + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + return response + +prompt = "Jakie są prawa pracownika zgodnie z Kodeksem pracy?" +response = generate_response(prompt) +print(response) \ No newline at end of file