commit fb53c760e9 (parent b14dc7f278)
l.gabrysiak, 2025-02-25 20:59:06 +01:00
2 changed files with 32 additions and 29 deletions

hft.py (53 changed lines)

@@ -12,7 +12,7 @@ import json
from collections import defaultdict
from huggingface_hub import login
-# Configuration
+# Environment configuration
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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")
@@ -94,30 +94,6 @@ def prepare_dataset(directory, catalog_path, source_mapper):
        })
    return data
-def tokenize_function(examples):
-    tokenized = tokenizer(
-        examples["text"],
-        truncation=True,
-        padding="max_length",
-        max_length=512,
-        return_tensors="pt"
-    )
-    tokenized["labels"] = tokenized["input_ids"].clone()
-    tokenized["source_idx"] = examples["source_idx"]
-    return tokenized
-def custom_collate_fn(batch):
-    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
-    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
-    labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
-    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "labels": labels,
-        "source_idx": source_idx
-    }
class CustomModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
@@ -135,6 +111,13 @@ class CustomModel(nn.Module):
    def generate(self, *args, **kwargs):
        return self.base_model.generate(*args, **kwargs)
+class CustomTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        labels = inputs.pop("labels")
+        source_idx = inputs.pop("source_idx", None)
+        outputs = model(**inputs, labels=labels, source_idx=source_idx)
+        return (outputs.loss, outputs) if return_outputs else outputs.loss
def main():
    # Initialization
    source_mapper = SourceMapper()
@@ -146,6 +129,20 @@ def main():
    catalog_path = "file_catalog.json"
    data = prepare_dataset("files", catalog_path, source_mapper)
    dataset = Dataset.from_list(data)
+    # Tokenization
+    def tokenize_function(examples):
+        tokenized = tokenizer(
+            examples["text"],
+            truncation=True,
+            padding="max_length",
+            max_length=512,
+            return_tensors="pt"
+        )
+        tokenized["labels"] = tokenized["input_ids"].clone()
+        tokenized["source_idx"] = examples["source_idx"]
+        return tokenized
    tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
    # Model
@@ -168,11 +165,10 @@ def main():
        report_to="none"
    )
-    trainer = Trainer(
+    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
-        data_collator=custom_collate_fn,
    )
print("Rozpoczęcie treningu...")
trainer.train()
@@ -192,8 +188,7 @@ def main():
        pad_token_id=tokenizer.eos_token_id
    )
-    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    answer = answer.replace(question, "").strip()
+    answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(question, "").strip()
    sources = set()
    for match in re.finditer(r'Art\.\s+\d+', answer):

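Taken together, the hft.py changes move tokenization inside main(), drop the custom collator in favour of the Trainer's default handling, and route the per-sample source_idx through the new CustomTrainer.compute_loss. For that call to work, CustomModel.forward has to accept labels and source_idx as keyword arguments; with remove_unused_columns left at its default, Trainer also keeps the source_idx dataset column only because the forward signature names it. Below is a minimal sketch of a compatible forward, assuming the wrapper holds a causal LM loaded with AutoModelForCausalLM and simply carries source_idx along without using it in the loss; the repository's actual forward may differ.

import torch.nn as nn
from transformers import AutoModelForCausalLM

class CustomModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
        # Assumption: the wrapped model is a causal LM, matching the generate() passthrough in the diff.
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        # source_idx tags each sample with its source document; it is accepted here so that
        # CustomTrainer.compute_loss can pass it, but the returned loss is the base model's
        # standard language-modelling loss.
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

    def generate(self, *args, **kwargs):
        return self.base_model.generate(*args, **kwargs)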
requirements.txt (new file, 8 additions)

@@ -0,0 +1,8 @@
torch>=2.0.1
transformers>=4.30.2
datasets>=2.13.1
Pillow>=9.4.0
pytesseract>=0.3.10
python-docx>=0.8.11
PyPDF2>=3.0.1
huggingface-hub>=0.16.4
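
The new requirements.txt pins minimum versions for the project's dependencies: the torch/transformers/datasets training stack, Pillow and pytesseract for OCR, python-docx and PyPDF2 for document parsing, and huggingface-hub for authentication. Assuming a standard Python environment, they are installed the usual way:

pip install -r requirements.txt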