From 0b9bb7a371c2f128e975c3be916243d7e5a82923 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Wed, 26 Feb 2025 09:32:16 +0100
Subject: [PATCH] mod gpt

---
 gpt.py | 44 ++++++++++++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 12 deletions(-)

diff --git a/gpt.py b/gpt.py
index 213c655..81d8398 100644
--- a/gpt.py
+++ b/gpt.py
@@ -1,4 +1,5 @@
 import os
+import re
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
 from datasets import Dataset
@@ -7,12 +8,29 @@ from datasets import Dataset
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 MODEL_NAME = "gpt2"
 SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"]
+TEXT_FILE_PATH = "scieżka/do/pliku_z_kodeksem.txt"  # Change this to the actual path
 
-def prepare_simple_dataset():
-    return [
-        {"text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu..."},
-        {"text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst..."}
-    ]
+def prepare_dataset_from_file(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        text = f.read()
+
+    # Split the text into articles using a regular expression
+    articles = re.findall(r'Art\.\s*\d+[a-z]*\..*?(?=\s*Art\.\s*\d+[a-z]*\.|\Z)', text, flags=re.DOTALL)
+
+    formatted_articles = []
+    for article in articles:
+        # Collapse redundant whitespace
+        article = ' '.join(article.strip().split())
+
+        # Extract the article number
+        art_match = re.match(r'Art\.\s*(\d+[a-z]*)\.\s*(.*)', article, re.DOTALL)
+        if art_match:
+            art_number = art_match.group(1)
+            art_text = art_match.group(2)
+            formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END] {art_text}"
+            formatted_articles.append({"text": formatted})
+
+    return formatted_articles
 
 def main():
     # Inicjalizacja tokenizera
@@ -21,16 +39,16 @@ def main():
     tokenizer.pad_token = tokenizer.eos_token
 
     # Przygotowanie danych
-    data = prepare_simple_dataset()
+    data = prepare_dataset_from_file(TEXT_FILE_PATH)
     dataset = Dataset.from_dict({"text": [d["text"] for d in data]})
 
-    # Tokenizacja z prawidłowymi etykietami
+    # Tokenization
     def tokenize_function(examples):
         tokenized = tokenizer(
             examples["text"],
             truncation=True,
             padding="max_length",
-            max_length=128,
+            max_length=256,  # Increased for longer articles
             return_tensors="pt"
         )
         tokenized["labels"] = tokenized["input_ids"].clone()
@@ -50,11 +68,12 @@ def main():
     # Konfiguracja treningu
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=1,
+        num_train_epochs=3,  # Increased the number of epochs
        per_device_train_batch_size=2,
-        remove_unused_columns=True,
-        logging_steps=1,
-        report_to="none"
+        learning_rate=5e-5,
+        logging_steps=10,
+        report_to="none",
+        save_strategy="no"
     )
 
     # Trainer
@@ -67,6 +86,7 @@ def main():
 
     print("Rozpoczęcie treningu...")
     trainer.train()
+    trainer.save_model("./trained_model")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file