l.gabrysiak 2025-02-25 17:57:49 +01:00
parent 8f8843fbb2
commit 359ca70172
1 changed file with 4 additions and 14 deletions

hft.py

@@ -1,7 +1,7 @@
import os
import torch
import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import Dataset
from PIL import Image
import re
@@ -117,7 +117,7 @@ def custom_collate_fn(batch):
class CustomModel(nn.Module):
    def __init__(self, model_name, config):
-        super().__init__(config)
+        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        self.source_embedding = nn.Embedding(
            num_embeddings=1000,
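The fix above is needed because `nn.Module.__init__` accepts no positional `config` argument; that signature belongs to `PreTrainedModel`, which also explains why the `PreTrainedModel` import is dropped in the first hunk. A minimal sketch of the wrapper consistent with these lines follows; only the attribute names and `num_embeddings=1000` come from the diff, while `embedding_dim` and the `forward` pass are illustrative assumptions:

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM

class CustomModel(nn.Module):
    """Sketch: a causal LM wrapped in nn.Module plus a separate source embedding."""

    def __init__(self, model_name, config):
        super().__init__()  # nn.Module.__init__ takes no arguments
        self.config = config  # keep the HF config as a plain attribute
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
        self.source_embedding = nn.Embedding(
            num_embeddings=1000,               # size taken from the hunk above
            embedding_dim=config.hidden_size,  # assumed; not shown in the diff
        )

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # Delegate loss/generation to the wrapped model; the source embedding
        # is consumed elsewhere in hft.py.
        return self.base_model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
        )
```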
@@ -160,7 +160,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
# Initialize the model
config = AutoModelForCausalLM.from_pretrained(model_name).config
-#print("Vocabulary size:", config.vocab_size)
+print("Vocabulary size:", config.vocab_size)
model = CustomModel(model_name, config)
model.to("cpu")  # Changed to CPU for debugging
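Re-enabling the vocabulary-size print is harmless, but note that `AutoModelForCausalLM.from_pretrained(model_name).config` instantiates the full model just to read its configuration. A lighter sketch, assuming only the config is needed at this point and that `model_name` is defined earlier in the script:

```python
from transformers import AutoConfig

# Fetch only the configuration, without loading model weights.
config = AutoConfig.from_pretrained(model_name)
print("Vocabulary size:", config.vocab_size)
```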
@@ -204,14 +204,4 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
    # Get the source from the last token
    last_token_id = outputs.sequences[0][-1].item()
    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
    source = source_mapper.get_source(source_idx)
    return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
# Example usage
question = "Ile dni urlopu przysługuje pracownikowi?"
answer = generate_answer(question, model, tokenizer, source_mapper)
print("Pytanie:", question)
print("Odpowiedź:", answer)
source_idx = model.source_embeddi
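Most of this hunk deletes the ad-hoc example usage from the end of the script (the last, truncated line above is a remnant of the old source lookup). If a quick smoke test is still wanted, one option is to keep the deleted call behind a `__main__` guard so that importing hft.py does not trigger generation; the call below mirrors the removed lines, and `model`, `tokenizer`, and `source_mapper` are assumed to be built earlier in the file:

```python
if __name__ == "__main__":
    # Smoke test mirroring the removed example usage.
    question = "Ile dni urlopu przysługuje pracownikowi?"
    answer = generate_answer(question, model, tokenizer, source_mapper)
    print("Pytanie:", question)
    print("Odpowiedź:", answer)
```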