diff --git a/hft.py b/hft.py
index a3069f3..0e14915 100644
--- a/hft.py
+++ b/hft.py
@@ -12,7 +12,7 @@
 import json
 from collections import defaultdict
 from huggingface_hub import login
-# Configuration
+# Environment configuration
 os.environ['TORCH_USE_CUDA_DSA'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 login(token=os.environ["HF_TOKEN"])  # read the token from the environment; never commit credentials
@@ -94,30 +94,6 @@ def prepare_dataset(directory, catalog_path, source_mapper):
     })
     return data
 
-def tokenize_function(examples):
-    tokenized = tokenizer(
-        examples["text"],
-        truncation=True,
-        padding="max_length",
-        max_length=512,
-        return_tensors="pt"
-    )
-    tokenized["labels"] = tokenized["input_ids"].clone()
-    tokenized["source_idx"] = examples["source_idx"]
-    return tokenized
-
-def custom_collate_fn(batch):
-    input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch])
-    attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch])
-    labels = torch.stack([torch.tensor(b["labels"]) for b in batch])
-    source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long)
-    return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "labels": labels,
-        "source_idx": source_idx
-    }
-
 class CustomModel(nn.Module):
     def __init__(self, model_name, config):
         super().__init__()
@@ -135,6 +111,13 @@ class CustomModel(nn.Module):
     def generate(self, *args, **kwargs):
         return self.base_model.generate(*args, **kwargs)
 
+class CustomTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        labels = inputs.pop("labels")
+        source_idx = inputs.pop("source_idx", None)
+        outputs = model(**inputs, labels=labels, source_idx=source_idx)
+        return (outputs.loss, outputs) if return_outputs else outputs.loss
+
 def main():
     # Initialization
     source_mapper = SourceMapper()
@@ -146,6 +129,20 @@ def main():
     catalog_path = "file_catalog.json"
     data = prepare_dataset("files", catalog_path, source_mapper)
     dataset = Dataset.from_list(data)
+
+    # Tokenization
+    def tokenize_function(examples):
+        tokenized = tokenizer(
+            examples["text"],
+            truncation=True,
+            padding="max_length",
+            max_length=512,
+            return_tensors="pt"
+        )
+        tokenized["labels"] = tokenized["input_ids"].clone()
+        tokenized["source_idx"] = examples["source_idx"]
+        return tokenized
+
     tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 
     # Model
@@ -168,11 +165,10 @@ def main():
         report_to="none"
     )
 
-    trainer = Trainer(
+    trainer = CustomTrainer(
         model=model,
         args=training_args,
         train_dataset=tokenized_dataset,
-        data_collator=custom_collate_fn,
     )
     print("Starting training...")
     trainer.train()
@@ -192,8 +188,7 @@ def main():
         pad_token_id=tokenizer.eos_token_id
     )
 
-    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    answer = answer.replace(question, "").strip()
+    answer = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(question, "").strip()
 
     sources = set()
     for match in re.finditer(r'Art\.\s+\d+', answer):
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..cfc1745
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+torch>=2.0.1
+transformers>=4.30.2
+datasets>=2.13.1
+Pillow>=9.4.0
+pytesseract>=0.3.10
+python-docx>=0.8.11
+PyPDF2>=3.0.1
+huggingface-hub>=0.16.4
\ No newline at end of file
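
Note on the CustomTrainer hook: compute_loss now forwards source_idx into the model call, but the hunks above only show CustomModel.generate, not its forward. For the call to succeed, forward has to accept a source_idx keyword. A minimal sketch of what that signature might look like; the parameter handling here is an assumption, not part of the patch:

import torch.nn as nn

class CustomModel(nn.Module):
    # ... __init__ as in hft.py ...

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None):
        # source_idx is accepted (and currently unused) so that
        # CustomTrainer.compute_loss can pass it through; per-source loss
        # weighting could hook in here later.
        return self.base_model(input_ids=input_ids,
                               attention_mask=attention_mask,
                               labels=labels)

The **kwargs in the compute_loss override also absorbs the num_items_in_batch argument that newer transformers releases pass to compute_loss, so the signature stays compatible across versions.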
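
Dropping custom_collate_fn means the default collator now has to see source_idx, and Trainer prunes dataset columns that do not appear in the model's forward signature (remove_unused_columns defaults to True). If forward does not list source_idx, the column is silently dropped before compute_loss ever runs. One way to guard against that, assuming the column names used above; a sketch, not part of the patch:

training_args = TrainingArguments(
    # ... existing arguments ...
    remove_unused_columns=False,  # keep source_idx for CustomTrainer.compute_loss
    report_to="none"
)

# With pruning disabled, the raw "text" strings would reach the default
# collator and fail tensor conversion, so drop them during tokenization:
tokenized_dataset = dataset.map(tokenize_function, batched=True,
                                batch_size=8, remove_columns=["text"])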
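
One refinement worth considering in tokenize_function: with padding="max_length", cloning input_ids into labels makes the loss score every pad position. The usual fix is to set pad positions to -100 so the loss function ignores them; masking via the attention mask avoids ambiguity when pad_token equals eos_token, as is common for GPT-2. A sketch under those assumptions:

def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512
    )
    # Copy input_ids to labels, but replace pad positions (attention mask 0)
    # with -100 so CrossEntropyLoss skips them.
    tokenized["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized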
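
Finally, a quick sanity check on the citation regex: \s+ requires at least one whitespace character between "Art." and the number, so a citation written as "Art.3" would not match. A small standalone check with a made-up sample string:

import re

sample = "See Art. 15 and Art. 7; compare Art.3."
print([m.group(0) for m in re.finditer(r'Art\.\s+\d+', sample)])  # ['Art. 15', 'Art. 7']
print([m.group(0) for m in re.finditer(r'Art\.\s*\d+', sample)])  # also catches 'Art.3'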