diff --git a/hft.py b/hft.py
index 17c259f..3ef5864 100644
--- a/hft.py
+++ b/hft.py
@@ -1,8 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-#from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
-from transformers import GPTNeoForCausalLM  # Changed import
+from transformers import AutoTokenizer, GPTNeoForCausalLM  # Fixed imports
 from datasets import Dataset
 from PIL import Image
 import re
@@ -151,7 +150,7 @@ class CustomTrainer(Trainer):
 
 # Initialize components
 source_mapper = SourceMapper()
-model_name = "EleutherAI/gpt-neo-2.7B"  # "google/gemma-2-2b"
+model_name = "EleutherAI/gpt-neo-2.7B"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token
 
@@ -162,9 +161,8 @@ dataset = Dataset.from_list(data)
 tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
 
 # Initialize the model
-config = AutoModelForCausalLM.from_pretrained(model_name).config
-#model = CustomModel.from_pretrained(model_name, config=config)
-model = CustomModel.from_pretrained(model_name)
+config = GPTNeoForCausalLM.from_pretrained(model_name).config
+model = CustomModel.from_pretrained(model_name, config=config)
 model.resize_token_embeddings(len(tokenizer))
 model.gradient_checkpointing_enable()
 
@@ -210,7 +208,7 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
 
     # Get the source from the last token
     last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
+    source_idx = last_token_id % 1000  # Updated way of determining the source
     source = source_mapper.get_source(source_idx)
     return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
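
A minimal sketch of the source-lookup path that the last hunk rewires: the id of the final generated token is folded into a 0-999 index with `% 1000` and resolved through the SourceMapper, falling back to "Opracowanie własne" when no source is registered. The SourceMapper body here is a hypothetical stand-in (the diff only shows that it is constructed with no arguments and exposes get_source); the real class in hft.py may differ.

# Hypothetical stand-in for SourceMapper: maps integer source indices to source names.
class SourceMapper:
    def __init__(self):
        # Example entries for illustration only; the real mapper is built elsewhere in hft.py.
        self.sources = {0: "dokumentacja.pdf", 1: "notatki.txt"}

    def get_source(self, idx):
        # Returns None for unknown indices, which triggers the fallback below.
        return self.sources.get(idx)


def source_for_last_token(last_token_id, source_mapper):
    # Mirrors the updated generate_answer logic: fold the last token id into a
    # 0..999 index instead of always reading the last row of the source embedding.
    source_idx = last_token_id % 1000
    source = source_mapper.get_source(source_idx)
    return source if source else "Opracowanie własne"


if __name__ == "__main__":
    mapper = SourceMapper()
    print(source_for_last_token(50256, mapper))  # 50256 % 1000 = 256 -> unknown -> fallback
    print(source_for_last_token(2001, mapper))   # 2001 % 1000 = 1 -> "notatki.txt"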