mod

parent 7796cc9ef0
commit 9cbcaa36ee

hft.py (4 lines changed)
@@ -164,6 +164,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
 # Initialize the model
 config = AutoModelForCausalLM.from_pretrained(model_name).config
 model = CustomModel.from_pretrained(model_name, config=config)
+model.gradient_checkpointing_enable()
 
 # Training configuration
 training_args = TrainingArguments(
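The added model.gradient_checkpointing_enable() trades extra forward-pass compute for lower activation memory. Below is a minimal sketch of the same setting against the stock transformers API, assuming CustomModel inherits from PreTrainedModel; the "gpt2" checkpoint name is only an illustrative placeholder for whatever model_name the script already defines:

from transformers import AutoModelForCausalLM

model_name = "gpt2"  # placeholder; the script defines its own model_name

model = AutoModelForCausalLM.from_pretrained(model_name)

# Recompute activations in the backward pass instead of storing them all:
# roughly one extra forward pass in exchange for a large memory saving.
model.gradient_checkpointing_enable()

# The attention cache is not useful during training with checkpointing,
# so disable it to avoid the incompatibility warning.
model.config.use_cache = False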
@@ -185,7 +186,8 @@ trainer = CustomTrainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_dataset,
-    data_collator=custom_collate_fn  # use the custom collate_fn
+    data_collator=custom_collate_fn,  # use the custom collate_fn
+    batch_size=8  # reduce the batch size
 )
 
 trainer.train()
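A note on the second hunk: the stock transformers.Trainer has no batch_size keyword, so passing batch_size=8 presumably relies on CustomTrainer accepting it; otherwise the call raises a TypeError. With the standard API, the batch size (and gradient checkpointing) are configured through TrainingArguments instead. A minimal sketch under that assumption, where output_dir and the accumulation factor are illustrative placeholders:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",          # placeholder output path
    per_device_train_batch_size=8,   # standard knob for shrinking the batch
    gradient_accumulation_steps=4,   # optional: keeps the effective batch at 32
    gradient_checkpointing=True,     # same effect as model.gradient_checkpointing_enable()
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=custom_collate_fn,  # batch size is not passed to the Trainer itself
)

trainer.train()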