commit bf034eaf8f (parent 9afa461252)

hft.py (modified) | 34
@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
 import re
 import json
@@ -11,6 +11,7 @@ import pytesseract
 from PIL import Image
 from collections import defaultdict
 from huggingface_hub import login
+from transformers import DataCollatorForLanguageModeling
 
 # Configuration
 os.environ['TORCH_USE_CUDA_DSA'] = '1'
@@ -179,6 +180,17 @@ class CustomTrainer(Trainer):
         outputs = model(**inputs, source_idx=source_idx)
         return (outputs.loss, outputs) if return_outputs else outputs.loss
 
+class CustomDataCollator(DataCollatorForLanguageModeling):
+    def torch_call(self, examples):
+        batch = super().torch_call(examples)
+
+        # Add source_idx to the batch
+        if "source_idx" in examples[0]:
+            source_idx = [ex["source_idx"] for ex in examples]
+            batch["source_idx"] = torch.tensor(source_idx, dtype=torch.long)
+
+        return batch
+
 def main():
     source_mapper = SourceMapper()
     model_name = "crumb/nano-mistral"
@@ -193,8 +205,7 @@ def main():
         print("\nNo training data!")
         return
 
-    dataset = Dataset.from_list(data)
-
+    # Prepare the dataset
     def tokenize_function(examples):
         tokenized = tokenizer(
             examples["text"],
@@ -204,19 +215,16 @@ def main():
             return_tensors="pt"
         )
         return {
-            "input_ids": tokenized["input_ids"][0],
-            "attention_mask": tokenized["attention_mask"][0],
-            "labels": tokenized["input_ids"][0].clone(),
+            "input_ids": tokenized["input_ids"].squeeze(),
+            "attention_mask": tokenized["attention_mask"].squeeze(),
+            "labels": tokenized["input_ids"].squeeze().clone(),
             "source_idx": examples["source_idx"]
         }
 
-    tokenized_dataset = dataset.map(tokenize_function, batched=False)
-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer,
-        mlm=False
-    )
+    dataset = Dataset.from_list(data)
+    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
 
     # Model and training
     model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
     model.source_mapper = source_mapper
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -240,7 +248,7 @@ def main():
         model=model,
         args=training_args,
         train_dataset=tokenized_dataset,
-        data_collator=data_collator,
+        data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False),
         tokenizer=tokenizer
     )
 
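For context on the diff above: CustomDataCollator only extends DataCollatorForLanguageModeling.torch_call, so the per-example source_idx produced by tokenize_function survives collation as a tensor in the batch the Trainer hands to the model. Below is a minimal standalone sketch of that behaviour; it is not part of the commit, the sample texts and indices are made up, and the pad-token fallback is an assumption in case crumb/nano-mistral ships without one.

# Standalone sketch (not part of this commit): check that CustomDataCollator
# keeps "source_idx" next to the usual language-modeling fields.
import torch
from transformers import AutoTokenizer, DataCollatorForLanguageModeling


class CustomDataCollator(DataCollatorForLanguageModeling):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        # Carry the per-example source index through collation as a long tensor.
        if "source_idx" in examples[0]:
            source_idx = [ex["source_idx"] for ex in examples]
            batch["source_idx"] = torch.tensor(source_idx, dtype=torch.long)
        return batch


tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
if tokenizer.pad_token is None:
    # Assumption: fall back to EOS if the checkpoint defines no pad token.
    tokenizer.pad_token = tokenizer.eos_token

collator = CustomDataCollator(tokenizer=tokenizer, mlm=False)

# Made-up examples; in hft.py these come from tokenize_function.
examples = [
    {"input_ids": tokenizer("first example")["input_ids"], "source_idx": 0},
    {"input_ids": tokenizer("a second, longer example text")["input_ids"], "source_idx": 2},
]
batch = collator(examples)
print(batch["input_ids"].shape, batch["labels"].shape, batch["source_idx"])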