mod
This commit is contained in:
parent
7b6dad7f2b
commit
0db71fc40d
34
hft.py
34
hft.py
|
|
@ -5,17 +5,17 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
|
|||
from datasets import Dataset
|
||||
import re
|
||||
import json
|
||||
import PyPDF2
|
||||
import docx2txt
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
from collections import defaultdict
|
||||
from huggingface_hub import login
|
||||
import PyPDF2 # Dodane
|
||||
import docx2txt # Dodane
|
||||
import pytesseract # Dodane
|
||||
from PIL import Image # Dodane
|
||||
|
||||
# Konfiguracja
|
||||
os.environ['TORCH_USE_CUDA_DSA'] = '1'
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
login(token="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") # Zastąp swoim tokenem
|
||||
login(token="TWÓJ_TOKEN_HF")
|
||||
|
||||
class SourceMapper:
|
||||
def __init__(self):
|
||||
|
|
@ -57,6 +57,8 @@ def extract_text_from_file(file_path):
|
|||
return text
|
||||
elif ext in ['.doc', '.docx']:
|
||||
return docx2txt.process(file_path)
|
||||
elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
|
||||
return pytesseract.image_to_string(Image.open(file_path))
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
|
@ -138,12 +140,23 @@ def main():
|
|||
max_length=512,
|
||||
return_tensors="pt"
|
||||
)
|
||||
tokenized["labels"] = tokenized["input_ids"].clone()
|
||||
tokenized["source_idx"] = examples["source_idx"]
|
||||
return tokenized
|
||||
return {
|
||||
"input_ids": tokenized["input_ids"],
|
||||
"attention_mask": tokenized["attention_mask"],
|
||||
"labels": tokenized["input_ids"].clone(),
|
||||
"source_idx": examples["source_idx"]
|
||||
}
|
||||
|
||||
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
|
||||
|
||||
def custom_collate_fn(features):
|
||||
return {
|
||||
"input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in features]),
|
||||
"attention_mask": torch.stack([torch.tensor(f["attention_mask"]) for f in features]),
|
||||
"labels": torch.stack([torch.tensor(f["labels"]) for f in features]),
|
||||
"source_idx": torch.tensor([f["source_idx"] for f in features], dtype=torch.long)
|
||||
}
|
||||
|
||||
# Model
|
||||
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
||||
model = CustomModel(model_name, config)
|
||||
|
|
@ -162,14 +175,15 @@ def main():
|
|||
save_strategy="steps",
|
||||
save_steps=1000,
|
||||
report_to="none",
|
||||
weight_decay=0.01
|
||||
weight_decay=0.01,
|
||||
remove_unused_columns=False
|
||||
)
|
||||
|
||||
trainer = CustomTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_dataset,
|
||||
data_collator=lambda x: x
|
||||
data_collator=custom_collate_fn
|
||||
)
|
||||
print("Rozpoczęcie treningu...")
|
||||
trainer.train()
|
||||
|
|
|
|||
Loading…
Reference in New Issue