Trainer fix, continued
commit d4f742d0a8
parent 6d1150308b
hft.py | 5
@@ -99,13 +99,14 @@ class CustomModel(AutoModelForCausalLM):
 # Custom Trainer
 class CustomTrainer(Trainer):
-    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs.pop("labels")
-        source = inputs.pop("source")
+        source = inputs.pop("source", None)  # Use None as the default value
         outputs = model(**inputs, labels=labels)
         loss = outputs.loss
         return (loss, outputs) if return_outputs else loss


 # Prepare the model and tokenizer
 model_name = "google/gemma-2-2b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+
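Note on the change: newer transformers releases (around v4.46 and later) pass an extra num_items_in_batch argument to Trainer.compute_loss, while older releases do not. Below is a minimal sketch, not the repository file, of an override that tolerates either call signature; it assumes each batch may carry an extra "source" column next to the usual input_ids/attention_mask/labels, and the **kwargs catch-all plus the surrounding import are illustrative additions, not part of this commit.

    # Sketch only: same idea as the patched override, written to accept
    # whatever extra keyword arguments the installed Trainer passes in.
    from transformers import Trainer

    class CustomTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extras such as num_items_in_batch on newer
            # transformers versions; older versions simply never fill it.
            labels = inputs.pop("labels")
            # Drop the custom "source" column so it is not forwarded to the
            # model; the None default avoids a KeyError when it is absent.
            source = inputs.pop("source", None)
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            return (loss, outputs) if return_outputs else loss

Popping "source" before the forward pass matters because model(**inputs) would otherwise receive an unexpected keyword argument.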