diff --git a/gpt.py b/gpt.py index cce91f3..f35a456 100644 --- a/gpt.py +++ b/gpt.py @@ -99,7 +99,7 @@ def main(): print("Rozpoczęcie treningu...") trainer.train() - trainer.save_model("./trained_model") + trainer.save_model("./trained_model/gpt") if __name__ == "__main__": main() \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..9f56d58 --- /dev/null +++ b/test.py @@ -0,0 +1,15 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_path = "./trained_model/gpt" +model = AutoModelForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) + +def generate_response(prompt, max_length=100): + inputs = tokenizer(prompt, return_tensors="pt") + outputs = model.generate(inputs.input_ids, max_length=max_length, num_return_sequences=1, do_sample=True) + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + return response + +prompt = "Jakie są prawa pracownika zgodnie z Kodeksem pracy?" +response = generate_response(prompt) +print(response) \ No newline at end of file