mod
This commit is contained in:
parent
b14dc7f278
commit
fb53c760e9
53
hft.py
53
hft.py
|
|
@ -12,7 +12,7 @@ import json
|
|||
from collections import defaultdict
|
||||
from huggingface_hub import login
|
||||
|
||||
# Environment configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Authenticate against the Hugging Face Hub.
# SECURITY: the access token must never be hard-coded in source control.
# Any token that was previously committed here has to be revoked and
# replaced. The token is read from the environment instead; when none is
# set, the explicit login is skipped so that credentials cached by
# `huggingface-cli login` can still be used.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if hf_token:
    login(token=hf_token)
|
||||
|
|
@ -94,30 +94,6 @@ def prepare_dataset(directory, catalog_path, source_mapper):
|
|||
})
|
||||
return data
|
||||
|
||||
def tokenize_function(examples):
    """Tokenize a batch of examples and attach training labels.

    Args:
        examples: Mapping that provides a ``"text"`` entry (passed to the
            tokenizer) and a ``"source_idx"`` entry (copied through).

    Returns:
        The tokenizer output, truncated/padded to 512 tokens and returned
        as PyTorch tensors, extended with ``"labels"`` and
        ``"source_idx"``.
    """
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )

    labels = tokenized["input_ids"].clone()
    # Every sequence is padded to max_length, so without masking the loss
    # would also be computed on padding positions. -100 is the ignore index
    # used by Hugging Face language-model losses.
    # NOTE(review): assumes the model's forward pass forwards labels to such
    # a loss - confirm against CustomModel.forward.
    attention_mask = tokenized.get("attention_mask")
    if attention_mask is not None:
        labels[attention_mask == 0] = -100

    tokenized["labels"] = labels
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized
|
||||
|
||||
def custom_collate_fn(batch):
    """Collate a list of per-sample dicts into one dict of batched tensors.

    Args:
        batch: Sequence of mappings, each holding ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"``, and optionally
            ``"source_idx"``.

    Returns:
        Dict with the three sequence fields stacked along a new leading
        dimension and ``"source_idx"`` as a 1-D ``torch.long`` tensor,
        using ``-1`` for samples that carry no source index.
    """
    def _stack(field):
        # Convert every sample's field to a tensor, then stack them.
        return torch.stack([torch.tensor(sample[field]) for sample in batch])

    collated = {
        field: _stack(field)
        for field in ("input_ids", "attention_mask", "labels")
    }
    collated["source_idx"] = torch.tensor(
        [sample.get("source_idx", -1) for sample in batch],
        dtype=torch.long,
    )
    return collated
|
||||
|
||||
class CustomModel(nn.Module):
|
||||
def __init__(self, model_name, config):
|
||||
super().__init__()
|
||||
|
|
@ -135,6 +111,13 @@ class CustomModel(nn.Module):
|
|||
def generate(self, *args, **kwargs):
    """Delegate generation to the wrapped ``base_model``.

    All positional and keyword arguments are forwarded unchanged, and the
    base model's return value is passed back as-is.
    """
    delegate = self.base_model.generate
    return delegate(*args, **kwargs)
|
||||
|
||||
class CustomTrainer(Trainer):
    """Trainer variant that hands ``source_idx`` to the model with the labels."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on one batch and return its loss.

        ``"labels"`` and (if present) ``"source_idx"`` are removed from
        ``inputs`` and passed to the model as explicit keyword arguments;
        the remaining entries are forwarded unchanged.

        Args:
            model: Callable model whose output exposes a ``loss`` attribute.
            inputs: Batch mapping; must contain ``"labels"``.
            return_outputs: When true, return ``(loss, outputs)`` instead
                of the loss alone.
            **kwargs: Accepted for signature compatibility; unused here.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple when
            ``return_outputs`` is true.
        """
        targets = inputs.pop("labels")
        src_idx = inputs.pop("source_idx", None)
        result = model(**inputs, labels=targets, source_idx=src_idx)
        if return_outputs:
            return result.loss, result
        return result.loss
|
||||
|
||||
def main():
|
||||
# Inicjalizacja
|
||||
source_mapper = SourceMapper()
|
||||
|
|
@ -146,6 +129,20 @@ def main():
|
|||
catalog_path = "file_catalog.json"
|
||||
data = prepare_dataset("files", catalog_path, source_mapper)
|
||||
dataset = Dataset.from_list(data)
|
||||
|
||||
# Tokenizacja
|
||||
def tokenize_function(examples):
    """Tokenize a batch of examples and attach training labels.

    Args:
        examples: Mapping that provides a ``"text"`` entry (passed to the
            tokenizer) and a ``"source_idx"`` entry (copied through).

    Returns:
        The tokenizer output, truncated/padded to 512 tokens and returned
        as PyTorch tensors, extended with ``"labels"`` and
        ``"source_idx"``.
    """
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )

    labels = tokenized["input_ids"].clone()
    # Every sequence is padded to max_length, so without masking the loss
    # would also be computed on padding positions. -100 is the ignore index
    # used by Hugging Face language-model losses.
    # NOTE(review): assumes the model's forward pass forwards labels to such
    # a loss - confirm against CustomModel.forward.
    attention_mask = tokenized.get("attention_mask")
    if attention_mask is not None:
        labels[attention_mask == 0] = -100

    tokenized["labels"] = labels
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized
|
||||
|
||||
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
|
||||
|
||||
# Model
|
||||
|
|
@ -168,11 +165,10 @@ def main():
|
|||
report_to="none"
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
trainer = CustomTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_dataset,
|
||||
data_collator=custom_collate_fn,
|
||||
)
|
||||
print("Rozpoczęcie treningu...")
|
||||
trainer.train()
|
||||
|
|
@ -192,8 +188,7 @@ def main():
|
|||
pad_token_id=tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
answer = answer.replace(question, "").strip()
|
||||
answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(question, "").strip()
|
||||
|
||||
sources = set()
|
||||
for match in re.finditer(r'Art\.\s+\d+', answer):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
torch>=2.0.1
|
||||
transformers>=4.30.2
|
||||
datasets>=2.13.1
|
||||
Pillow>=9.4.0
|
||||
pytesseract>=0.3.10
|
||||
python-docx>=0.8.11
|
||||
PyPDF2>=3.0.1
|
||||
huggingface-hub>=0.16.4
|
||||
Loading…
Reference in New Issue