This commit is contained in:
l.gabrysiak 2025-02-26 09:32:16 +01:00
parent ffe1bf5eab
commit 0b9bb7a371
1 changed files with 32 additions and 12 deletions

44
gpt.py
View File

@ -1,4 +1,5 @@
import os
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import Dataset
@ -7,12 +8,29 @@ from datasets import Dataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
MODEL_NAME = "gpt2"
SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"]
TEXT_FILE_PATH = "scieżka/do/pliku_z_kodeksem.txt" # Zmień na właściwą ścieżkę
def prepare_simple_dataset():
return [
{"text": "[CITATION_START] Kodeks Pracy, Art. 1 [CITATION_END] Tekst artykułu..."},
{"text": "[CITATION_START] Kodeks Pracy, Art. 2 [CITATION_END] Inny tekst..."}
]
def prepare_dataset_from_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
text = f.read()
# Wydziel artykuły za pomocą wyrażenia regularnego
articles = re.findall(r'Art\.\s*\d+[a-z]*\..*?(?=\s*Art\.\s*\d+[a-z]*\.|\Z', text, flags=re.DOTALL)
formatted_articles = []
for article in articles:
# Usuń zbędne białe znaki
article = ' '.join(article.strip().split())
# Wydziel numer artykułu
art_match = re.match(r'Art\.\s*(\d+[a-z]*)\.\s*(.*)', article, re.DOTALL)
if art_match:
art_number = art_match.group(1)
art_text = art_match.group(2)
formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END] {art_text}"
formatted_articles.append({"text": formatted})
return formatted_articles
def main():
# Inicjalizacja tokenizera
@ -21,16 +39,16 @@ def main():
tokenizer.pad_token = tokenizer.eos_token
# Przygotowanie danych
data = prepare_simple_dataset()
data = prepare_dataset_from_file(TEXT_FILE_PATH)
dataset = Dataset.from_dict({"text": [d["text"] for d in data]})
# Tokenizacja z prawidłowymi etykietami
# Tokenizacja
def tokenize_function(examples):
tokenized = tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=128,
max_length=256, # Zwiększono dla dłuższych artykułów
return_tensors="pt"
)
tokenized["labels"] = tokenized["input_ids"].clone()
@ -50,11 +68,12 @@ def main():
# Konfiguracja treningu
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=1,
num_train_epochs=3, # Zwiększono liczbę epok
per_device_train_batch_size=2,
remove_unused_columns=True,
logging_steps=1,
report_to="none"
learning_rate=5e-5,
logging_steps=10,
report_to="none",
save_strategy="no"
)
# Trainer
@ -67,6 +86,7 @@ def main():
print("Rozpoczęcie treningu...")
trainer.train()
trainer.save_model("./trained_model")
if __name__ == "__main__":
main()