From b30937ceb48afbb31925c7f1e33ed9d85930efcb Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Wed, 26 Feb 2025 11:37:10 +0100
Subject: [PATCH] add model

---
 allegro.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 allegro.py

diff --git a/allegro.py b/allegro.py
new file mode 100644
index 0000000..6eef052
--- /dev/null
+++ b/allegro.py
@@ -0,0 +1,115 @@
+import os
+import re
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
+from datasets import Dataset
+
+# Configuration
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+MODEL_NAME = "allegro/herbert-base-cased"
+SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"]
+TEXT_FILE_PATH = "./docs/kodekspracy.txt"  # change to the correct path
+
+def prepare_dataset_from_file(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        text = f.read()
+
+    # Extract the individual articles with a regular expression
+    articles = re.findall(r'Art\.\s*\d+[a-z]*\..*?(?=\s*Art\.\s*\d+[a-z]*\.|\Z)', text, flags=re.DOTALL)
+
+    formatted_articles = []
+    for article in articles:
+        # Collapse redundant whitespace
+        article = ' '.join(article.strip().split())
+
+        # Extract the article number and its text
+        art_match = re.match(r'Art\.\s*(\d+[a-z]*)\.?\s*(.*)', article, re.DOTALL)
+        if art_match:
+            art_number = art_match.group(1)
+            art_text = art_match.group(2)
+
+            # Split into paragraphs, if any
+            paragraphs = re.split(r'(§\s*\d+\.)', art_text)
+            if len(paragraphs) > 1:
+                formatted_paragraphs = []
+                for i in range(1, len(paragraphs), 2):
+                    para_num = paragraphs[i].strip()
+                    para_text = paragraphs[i+1].strip()
+                    formatted_paragraphs.append(f"{para_num} {para_text}")
+                formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END]\n" + "\n".join(formatted_paragraphs)
+            else:
+                formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END] {art_text}"
+
+            formatted_articles.append({"text": formatted})
+
+            # Add question-and-answer examples
+            questions = [
+                f"Zacytuj artykuł {art_number} Kodeksu pracy.",
+                f"Co mówi artykuł {art_number} Kodeksu pracy?",
+                f"Podaj treść artykułu {art_number} Kodeksu pracy."
+            ]
+            for question in questions:
+                formatted_articles.append({"text": f"{question}\n{formatted}"})
+
+    return formatted_articles
+
+
+def main():
+    # Initialize the tokenizer and register the citation markers
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
+    if tokenizer.pad_token is None:  # fallback for tokenizers without a padding token
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Prepare the data
+    data = prepare_dataset_from_file(TEXT_FILE_PATH)
+    dataset = Dataset.from_dict({"text": [d["text"] for d in data]})
+
+    # Tokenization
+    def tokenize_function(examples):
+        return tokenizer(
+            examples["text"],
+            truncation=True,
+            padding="max_length",
+            max_length=512,  # 1024 would exceed the base model's position embeddings
+        )
+
+    tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+    # Model and data collator; is_decoder=True enables the causal attention mask
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, is_decoder=True)
+    model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
+
+    # mlm=False -> causal LM objective; the collator builds the labels from input_ids
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False
+    )
+
+    # Training configuration
+    training_args = TrainingArguments(
+        output_dir="./results",
+        num_train_epochs=32,  # increased number of epochs
+        per_device_train_batch_size=2,
+        learning_rate=1e-5,
+        logging_steps=10,
+        weight_decay=0.01,
+        report_to="none",
+        save_strategy="no",
+    )
+
+    # Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset,
+        data_collator=data_collator
+    )
+
+    print("Starting training...")
+    trainer.train()
+    trainer.save_model("./trained_model/gpt")
+    tokenizer.save_pretrained("./trained_model/gpt")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
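
After training, the checkpoint can be smoke-tested with a short generation script. This is a minimal sketch, assuming the ./trained_model/gpt directory saved by the script above; the article number in the prompt is only an illustrative example.

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned checkpoint saved by allegro.py (assumed output path)
tokenizer = AutoTokenizer.from_pretrained("./trained_model/gpt")
model = AutoModelForCausalLM.from_pretrained("./trained_model/gpt")
model.eval()

# Prompt in the same question format used to build the training set
prompt = "Zacytuj artykuł 1 Kodeksu pracy."
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=False))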