mod
This commit is contained in:
parent
9afa461252
commit
bf034eaf8f
34
hft.py
34
hft.py
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset
|
||||
import re
|
||||
import json
|
||||
|
|
@ -11,6 +11,7 @@ import pytesseract
|
|||
from PIL import Image
|
||||
from collections import defaultdict
|
||||
from huggingface_hub import login
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
# Konfiguracja
|
||||
os.environ['TORCH_USE_CUDA_DSA'] = '1'
|
||||
|
|
@ -179,6 +180,17 @@ class CustomTrainer(Trainer):
|
|||
outputs = model(**inputs, source_idx=source_idx)
|
||||
return (outputs.loss, outputs) if return_outputs else outputs.loss
|
||||
|
||||
class CustomDataCollator(DataCollatorForLanguageModeling):
    """Language-modeling collator that preserves per-example ``source_idx``.

    Delegates all standard batching (padding, label creation) to the parent
    ``DataCollatorForLanguageModeling`` and then copies the ``source_idx``
    field, when present, into the collated batch as a long tensor so a
    custom model/trainer can consume it.
    """

    def torch_call(self, examples):
        # Build the regular causal-LM batch via the parent implementation.
        batch = super().torch_call(examples)

        # If the examples carry a source index, gather it into one tensor
        # and attach it to the batch under the same key.
        if "source_idx" in examples[0]:
            indices = [sample["source_idx"] for sample in examples]
            batch["source_idx"] = torch.tensor(indices, dtype=torch.long)

        return batch
|
||||
|
||||
def main():
|
||||
source_mapper = SourceMapper()
|
||||
model_name = "crumb/nano-mistral"
|
||||
|
|
@ -193,8 +205,7 @@ def main():
|
|||
print("\nBrak danych do treningu!")
|
||||
return
|
||||
|
||||
dataset = Dataset.from_list(data)
|
||||
|
||||
# Przygotowanie datasetu
|
||||
def tokenize_function(examples):
|
||||
tokenized = tokenizer(
|
||||
examples["text"],
|
||||
|
|
@ -204,19 +215,16 @@ def main():
|
|||
return_tensors="pt"
|
||||
)
|
||||
return {
|
||||
"input_ids": tokenized["input_ids"][0],
|
||||
"attention_mask": tokenized["attention_mask"][0],
|
||||
"labels": tokenized["input_ids"][0].clone(),
|
||||
"input_ids": tokenized["input_ids"].squeeze(),
|
||||
"attention_mask": tokenized["attention_mask"].squeeze(),
|
||||
"labels": tokenized["input_ids"].squeeze().clone(),
|
||||
"source_idx": examples["source_idx"]
|
||||
}
|
||||
|
||||
tokenized_dataset = dataset.map(tokenize_function, batched=False)
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=False
|
||||
)
|
||||
dataset = Dataset.from_list(data)
|
||||
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
|
||||
|
||||
# Model i trening
|
||||
model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
|
||||
model.source_mapper = source_mapper
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
|
@ -240,7 +248,7 @@ def main():
|
|||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_dataset,
|
||||
data_collator=data_collator,
|
||||
data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False),
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue