mod

commit 30234c332a
parent b519db2b27

 hft.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)
@@ -1,7 +1,8 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, GPTNeoForCausalLM, Trainer  # fixed imports
+#from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
+from transformers import AutoTokenizer, AutoModelForCausalLM, GPTNeoForCausalLM, Trainer, TrainingArguments  # import change; AutoTokenizer and AutoModelForCausalLM are still used below
 from datasets import Dataset
 from PIL import Image
 import re
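The new import brings in TrainingArguments alongside Trainer. A minimal sketch of how the pair is typically wired up; the values below are illustrative assumptions, not the settings hft.py actually uses:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                  # checkpoint directory (illustrative)
    per_device_train_batch_size=1,     # a 2.7B model leaves little headroom per GPU
    gradient_accumulation_steps=8,     # recovers an effective batch size of 8
    num_train_epochs=1,
    fp16=True,                         # assumes a CUDA GPU
)
# trainer = CustomTrainer(model=model, args=training_args, train_dataset=tokenized_dataset)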
@@ -150,7 +151,7 @@ class CustomTrainer(Trainer):
 
 # Initialize components
 source_mapper = SourceMapper()
-model_name = "EleutherAI/gpt-neo-2.7B"
+model_name = "EleutherAI/gpt-neo-2.7B"  #"google/gemma-2-2b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token
 
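GPT-Neo ships without a padding token, so the pad_token = eos_token assignment above is what makes padded batching work at all. A minimal sketch, assuming Hugging Face Hub access:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
tokenizer.pad_token = tokenizer.eos_token   # without this, padding=True raises an error
batch = tokenizer(["short", "a somewhat longer sequence"],
                  padding=True, return_tensors="pt")
print(batch["input_ids"].shape)             # both rows padded to the longer length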
@@ -161,8 +162,9 @@ dataset = Dataset.from_list(data)
 tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=32)
 
 # Initialize the model
-config = GPTNeoForCausalLM.from_pretrained(model_name).config
-model = CustomModel.from_pretrained(model_name, config=config)
+config = AutoModelForCausalLM.from_pretrained(model_name).config
+#model = CustomModel.from_pretrained(model_name, config=config)
+model = CustomModel.from_pretrained(model_name)
 model.resize_token_embeddings(len(tokenizer))
 model.gradient_checkpointing_enable()
 
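The order of the two calls after loading matters: resize_token_embeddings must run after any tokenizer changes, or new token ids index past the embedding matrix. A minimal sketch, using the small 125M checkpoint as a stand-in and a hypothetical <|pad|> token:

from transformers import AutoTokenizer, GPTNeoForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})   # grows the vocab by one

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
model.resize_token_embeddings(len(tokenizer))   # embedding rows now match the vocab
model.gradient_checkpointing_enable()           # trades recompute for activation memory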
@@ -208,7 +210,7 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
 
     # Get the source from the last token
     last_token_id = outputs.sequences[0][-1].item()
-    source_idx = last_token_id % 1000  # update how the source is determined
+    source_idx = model.source_embedding.weight.shape[0] - 1  # temporary workaround
     source = source_mapper.get_source(source_idx)
 
     return f"{answer}\n\nSource: {source if source else 'Own work'}"
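Note that the new right-hand side always points source_idx at the last row of model.source_embedding, so every answer currently resolves to the same source; the old % 1000 heuristic at least varied with the sampled token. SourceMapper itself is defined earlier in hft.py and not shown in this diff; a hedged sketch of the interface the call above assumes:

class SourceMapper:
    """Hypothetical reconstruction; the real class may differ."""

    def __init__(self):
        self.sources = {}                 # source index -> citation string

    def add_source(self, idx, citation):
        self.sources[idx] = citation

    def get_source(self, idx):
        # None for unknown indices; generate_answer then falls back to 'Own work'.
        return self.sources.get(idx)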