This commit is contained in:
l.gabrysiak 2025-02-25 14:46:02 +01:00
parent b519db2b27
commit 30234c332a
1 changed files with 7 additions and 5 deletions

12
hft.py
View File

@ -1,7 +1,8 @@
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoTokenizer, GPTNeoForCausalLM, Trainer # Poprawiono importy #from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from transformers import GPTNeoForCausalLM, Trainer, TrainingArguments # Zmiana importu
from datasets import Dataset from datasets import Dataset
from PIL import Image from PIL import Image
import re import re
@ -150,7 +151,7 @@ class CustomTrainer(Trainer):
# Inicjalizacja komponentów # Inicjalizacja komponentów
source_mapper = SourceMapper() source_mapper = SourceMapper()
model_name = "EleutherAI/gpt-neo-2.7B" model_name = "EleutherAI/gpt-neo-2.7B" #"google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
@ -161,8 +162,9 @@ dataset = Dataset.from_list(data)
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32) tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
# Inicjalizacja modelu # Inicjalizacja modelu
config = GPTNeoForCausalLM.from_pretrained(model_name).config config = AutoModelForCausalLM.from_pretrained(model_name).config
model = CustomModel.from_pretrained(model_name, config=config) #model = CustomModel.from_pretrained(model_name, config=config)
model = CustomModel.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer)) model.resize_token_embeddings(len(tokenizer))
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
@ -208,7 +210,7 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
# Pobierz źródło z ostatniego tokena # Pobierz źródło z ostatniego tokena
last_token_id = outputs.sequences[0][-1].item() last_token_id = outputs.sequences[0][-1].item()
source_idx = last_token_id % 1000 # Zaktualizuj sposób określania źródła source_idx = model.source_embedding.weight.shape[0] - 1 # Tymczasowe rozwiązanie
source = source_mapper.get_source(source_idx) source = source_mapper.get_source(source_idx)
return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}" return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"