l.gabrysiak 2025-02-26 09:32:16 +01:00
parent ffe1bf5eab
commit 0b9bb7a371
1 changed file with 32 additions and 12 deletions

gpt.py

@@ -1,4 +1,5 @@
 import os
+import re
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
 from datasets import Dataset
@@ -7,12 +8,29 @@ from datasets import Dataset
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 MODEL_NAME = "gpt2"
 SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"]
+TEXT_FILE_PATH = "scieżka/do/pliku_z_kodeksem.txt"  # Change this to the correct path
 
-def prepare_simple_dataset():
-    return [
-        {"text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu..."},
-        {"text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst..."}
-    ]
+def prepare_dataset_from_file(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        text = f.read()
+
+    # Split out individual articles with a regular expression
+    articles = re.findall(r'Art\.\s*\d+[a-z]*\..*?(?=\s*Art\.\s*\d+[a-z]*\.|\Z)', text, flags=re.DOTALL)
+
+    formatted_articles = []
+    for article in articles:
+        # Collapse redundant whitespace
+        article = ' '.join(article.strip().split())
+        # Extract the article number and body
+        art_match = re.match(r'Art\.\s*(\d+[a-z]*)\.\s*(.*)', article, re.DOTALL)
+        if art_match:
+            art_number = art_match.group(1)
+            art_text = art_match.group(2)
+            formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END] {art_text}"
+            formatted_articles.append({"text": formatted})
+
+    return formatted_articles
 
 def main():
     # Initialize the tokenizer
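Not part of the commit, for illustration only: a minimal standalone sketch of how the article-splitting regex in prepare_dataset_from_file behaves on a small inline sample. The sample string and the expected output in the comments are hypothetical assumptions, not data from the repository.

import re

# Hypothetical two-article sample in the "Art. N." layout the regex expects.
sample = "Art. 1. Pierwszy przepis. Art. 2a. Drugi przepis."

# Each match runs from one "Art. N." heading up to the next heading or the end of the text.
articles = re.findall(r'Art\.\s*\d+[a-z]*\..*?(?=\s*Art\.\s*\d+[a-z]*\.|\Z)', sample, flags=re.DOTALL)
for article in articles:
    m = re.match(r'Art\.\s*(\d+[a-z]*)\.\s*(.*)', article, re.DOTALL)
    if m:
        print(f"[CITATION_START] Kodeks Pracy, Art. {m.group(1)} [CITATION_END] {m.group(2)}")

# Expected output:
# [CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Pierwszy przepis.
# [CITATION_START] Kodeks Pracy, Art. 2a [CITATION_END] Drugi przepis.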
@@ -21,16 +39,16 @@ def main():
     tokenizer.pad_token = tokenizer.eos_token
 
     # Prepare the data
-    data = prepare_simple_dataset()
+    data = prepare_dataset_from_file(TEXT_FILE_PATH)
     dataset = Dataset.from_dict({"text": [d["text"] for d in data]})
 
-    # Tokenization with correct labels
+    # Tokenization
     def tokenize_function(examples):
         tokenized = tokenizer(
             examples["text"],
             truncation=True,
             padding="max_length",
-            max_length=128,
+            max_length=256,  # Increased for longer articles
             return_tensors="pt"
         )
         tokenized["labels"] = tokenized["input_ids"].clone()
@@ -50,11 +68,12 @@ def main():
     # Training configuration
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=1,
+        num_train_epochs=3,  # Increased the number of epochs
         per_device_train_batch_size=2,
-        remove_unused_columns=True,
-        logging_steps=1,
-        report_to="none"
+        learning_rate=5e-5,
+        logging_steps=10,
+        report_to="none",
+        save_strategy="no"
     )
 
     # Trainer
@@ -67,6 +86,7 @@ def main():
     print("Rozpoczęcie treningu...")
     trainer.train()
+    trainer.save_model("./trained_model")
 
 if __name__ == "__main__":
     main()
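Not part of the commit: a hedged sketch of how the model written by trainer.save_model("./trained_model") might be loaded back and prompted with the citation format. The prompt text and generation settings are assumptions, and it presumes the tokenizer (with the special tokens) was saved to the same directory, which this diff does not show.

from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumption: the tokenizer was saved alongside the model (e.g. tokenizer.save_pretrained("./trained_model"));
# otherwise load the gpt2 tokenizer and re-add SPECIAL_TOKENS before generating.
tokenizer = AutoTokenizer.from_pretrained("./trained_model")
model = AutoModelForCausalLM.from_pretrained("./trained_model")

prompt = "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END]"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=False))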