diff --git a/gemma.py b/gemma.py index 86f6bd1..aeaa6c8 100644 --- a/gemma.py +++ b/gemma.py @@ -38,7 +38,7 @@ dataset = create_training_data() # 5️⃣ Ładowanie modelu Gemma 2 7B device = "cuda" if torch.cuda.is_available() else "cpu" -model_name = "google/gemma-2b" +model_name = "google/gemma-2-2b" model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) tokenizer = AutoTokenizer.from_pretrained(model_name)