diff --git a/gpt.py b/gpt.py index d3b56d8..0e3bdf2 100644 --- a/gpt.py +++ b/gpt.py @@ -44,7 +44,6 @@ def prepare_dataset_from_file(file_path): return formatted_articles - def main(): # Inicjalizacja tokenizera tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) @@ -61,17 +60,17 @@ def main(): examples["text"], truncation=True, padding="max_length", - max_length=256, # Zwiększono dla dłuższych artykułów + max_length=2048, # Zwiększono dla dłuższych artykułów return_tensors="pt" ) tokenized["labels"] = tokenized["input_ids"].clone() return tokenized - tokenized_dataset = dataset.map(tokenize_function, batched=True) + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names) # Model i data collator model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) - model.resize_token_embeddings(len(tokenizer), mean_resizing=False) + model.resize_token_embeddings(len(tokenizer)) data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, @@ -81,12 +80,17 @@ def main(): # Konfiguracja treningu training_args = TrainingArguments( output_dir="./results", - num_train_epochs=8, # Zwiększono liczbę epok - per_device_train_batch_size=2, - learning_rate=5e-5, + num_train_epochs=15, # Zwiększono liczbę epok + per_device_train_batch_size=4, # Zwiększono rozmiar batcha + learning_rate=2e-5, # Zmniejszono learning rate + weight_decay=0.01, # Dodano weight decay logging_steps=10, + save_steps=500, # Dodano zapisywanie modelu co 500 kroków + eval_steps=500, # Dodano ewaluację co 500 kroków + evaluation_strategy="steps", + load_best_model_at_end=True, # Ładowanie najlepszego modelu na końcu report_to="none", - save_strategy="no" + save_total_limit=2, # Ograniczenie liczby zapisywanych checkpointów ) # Trainer @@ -94,6 +98,7 @@ def main(): model=model, args=training_args, train_dataset=tokenized_dataset, + eval_dataset=tokenized_dataset, # Używamy tego samego zbioru do ewaluacji data_collator=data_collator ) @@ -103,4 +108,4 @@ def main(): tokenizer.save_pretrained("./trained_model/gpt") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/test.py b/test.py index 5d47e17..8143388 100644 --- a/test.py +++ b/test.py @@ -17,6 +17,6 @@ def generate_response(prompt, max_length=1000): response = tokenizer.decode(outputs[0], skip_special_tokens=True) return response -prompt = "Zacytuj art. 154 kodeksu pracy" +prompt = "Jak brzmi art. 154 kodeksu pracy" response = generate_response(prompt) print(response) \ No newline at end of file