From 204dd4421a41707d3aa5b6ad6db72b8bdb23b254 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 15:02:36 +0100
Subject: [PATCH] mod

Reduce GPU memory pressure during fine-tuning and simplify the training
setup:

- enable expandable CUDA allocator segments and add a free_memory()
  helper (empty_cache + ipc_collect), called once after training
- guard against PDF pages where PyPDF2 extract_text() returns None
- tighten the article-splitting regex to match only "Art. <n>." headings
- drop the CustomTrainer subclass in favour of the stock Trainer
- halve the per-device batch sizes (4 -> 2) and double gradient
  accumulation (4 -> 8), keeping the effective batch size at 16;
  remove report_to="none", tokenize with map batch_size=16 instead
  of 32, log every 50 steps and checkpoint every 500
---
 hft.py | 35 +++++++++++++++++------------------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/hft.py b/hft.py
index 266ccf3..1a679ce 100644
--- a/hft.py
+++ b/hft.py
@@ -18,6 +18,11 @@ torch.cuda.empty_cache()
 # Log in to the Hugging Face Hub
 login(token="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")
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+def free_memory():
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
 
 class SourceMapper:
     def __init__(self):
@@ -54,7 +59,7 @@ def extract_text_from_file(file_path):
         with open(file_path, 'rb') as file:
             reader = PyPDF2.PdfReader(file)
             for page in reader.pages:
-                text += page.extract_text()
+                text += page.extract_text() or ""
         return text
     elif ext in ['.doc', '.docx']:
         return docx2txt.process(file_path)
@@ -76,7 +81,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
        doc_type = identify_legal_document(file, file_catalog)
 
        if doc_type != "Opracowanie własne":
-           articles = re.split(r'(Art\.?\s+\d+[\.\s])', text)
+           articles = re.split(r'(Art\.\s+\d+\.)', text)
            for i in range(1, len(articles), 2):
                article_number = articles[i].strip()
                article_content = articles[i+1].strip() if i+1 < len(articles) else ""
@@ -137,14 +142,6 @@ class CustomModel(GPTNeoForCausalLM):
         outputs.logits += source_embeds
         return outputs
 
-class CustomTrainer(Trainer):
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
-        labels = inputs.pop("labels")
-        with autocast():
-            source_idx = inputs.pop("source_idx")
-            outputs = model(**inputs, labels=labels, source_idx=source_idx)
-        return (outputs.loss, outputs) if return_outputs else outputs.loss
-
 source_mapper = SourceMapper()
 model_name = "EleutherAI/gpt-neo-2.7B"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -152,7 +149,7 @@ tokenizer.pad_token = tokenizer.eos_token
 
 data = prepare_dataset("files", "file_catalog.json", source_mapper)
 dataset = Dataset.from_list(data)
-tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
+tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)
 
 config = AutoModelForCausalLM.from_pretrained(model_name).config
 model = CustomModel.from_pretrained(model_name)
@@ -164,26 +161,28 @@ model.gradient_checkpointing_enable()
 training_args = TrainingArguments(
     output_dir="./results",
     num_train_epochs=3,
-    gradient_accumulation_steps=4,
+    gradient_accumulation_steps=8,
     learning_rate=2e-5,
     fp16=True,
-    logging_steps=100,
+    logging_steps=50,
     save_strategy="steps",
-    save_steps=1000,
-    report_to="none",
-    per_device_train_batch_size=4,
-    per_device_eval_batch_size=4,
+    save_steps=500,
+    per_device_train_batch_size=2,
+    per_device_eval_batch_size=2,
     logging_dir='./logs'
 )
 
-trainer = CustomTrainer(
+trainer = Trainer(
     model=model,
     args=training_args,
     train_dataset=tokenized_dataset,
     data_collator=custom_collate_fn
 )
+
 trainer.train()
 
+free_memory()
+
 # Function that generates an answer
def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
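
Reviewer notes (appended after the diff; illustrative sketches, not changes to hft.py):

On the allocator tweak: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True lets the
CUDA caching allocator grow existing segments instead of reserving new fixed-size
blocks, which mitigates fragmentation-induced OOMs; free_memory() then returns
cached, unreferenced blocks to the driver. A minimal standalone sketch for
checking that the helper actually releases reserved memory:

    import torch

    def free_memory():
        # Same pattern as the patch: drop cached blocks and CUDA IPC handles.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    if torch.cuda.is_available():
        before = torch.cuda.memory_reserved()
        free_memory()
        after = torch.cuda.memory_reserved()
        print(f"reserved: {before} -> {after} bytes")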
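
On the regex change: the old pattern Art\.?\s+\d+[\.\s] also split on loose forms
such as "Art 2 " (no periods), while the new one only accepts the full "Art. N."
heading. A quick self-contained check with made-up sample text:

    import re

    sample = "Art. 1. First provision. See Art 2 in passing. Art. 3. Third provision."
    old = re.split(r'(Art\.?\s+\d+[\.\s])', sample)
    new = re.split(r'(Art\.\s+\d+\.)', sample)
    print(len(old))  # 7 pieces: the loose pattern also split on "Art 2 "
    print(len(new))  # 5 pieces: only "Art. 1." and "Art. 3." match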
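
On dropping CustomTrainer: the stock Trainer's compute_loss passes the whole batch
to model(**inputs), so the source_idx key that the subclass used to pop must now
either be accepted by CustomModel.forward or never reach the model. A hypothetical
collator sketch (custom_collate_fn itself is not shown in this patch, so the names
and field layout here are assumptions) that keeps source_idx flowing as a tensor:

    import torch

    def collate_with_source_idx(batch):
        # Hypothetical stand-in for custom_collate_fn: every key returned here
        # is forwarded by the stock Trainer into CustomModel.forward(**inputs),
        # so source_idx must be something the model's signature accepts.
        return {
            "input_ids": torch.tensor([ex["input_ids"] for ex in batch]),
            "attention_mask": torch.tensor([ex["attention_mask"] for ex in batch]),
            "labels": torch.tensor([ex["input_ids"] for ex in batch]),
            "source_idx": torch.tensor([ex["source_idx"] for ex in batch]),
        }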
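
On the batch-size changes: the effective batch per optimizer step is
per_device_train_batch_size x gradient_accumulation_steps (x number of devices),
so the patch halves activation memory without changing the optimization schedule:

    # Effective batch per optimizer step, single GPU:
    old_effective = 4 * 4   # per_device_train_batch_size=4, gradient_accumulation_steps=4
    new_effective = 2 * 8   # per_device_train_batch_size=2, gradient_accumulation_steps=8
    assert old_effective == new_effective == 16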