From 359ca701724dfa9b0f570c2cd585a2c04120b366 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 17:57:49 +0100
Subject: [PATCH] mod

---
 hft.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/hft.py b/hft.py
index 28d0795..86a3048 100644
--- a/hft.py
+++ b/hft.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
 from PIL import Image
 import re
@@ -117,7 +117,7 @@ def custom_collate_fn(batch):
 
 class CustomModel(nn.Module):
     def __init__(self, model_name, config):
-        super().__init__(config)
+        super().__init__()
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,
@@ -160,7 +160,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
-#print("Vocabulary size:", config.vocab_size)
+print("Vocabulary size:", config.vocab_size)
 
 model = CustomModel(model_name, config)
 model.to("cpu")  # Changed to CPU for debugging
@@ -204,14 +204,4 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
 
     # Get the source from the last token
     last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
-    source = source_mapper.get_source(source_idx)
-
-    return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
-
-# Example usage
-question = "Ile dni urlopu przysługuje pracownikowi?"
-answer = generate_answer(question, model, tokenizer, source_mapper)
-print("Pytanie:", question)
-print("Odpowiedź:", answer)
-
+    source_idx = model.source_embeddi
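
Note on the super().__init__() change: CustomModel subclasses nn.Module, and
nn.Module.__init__ takes no positional arguments, so super().__init__(config)
raises "TypeError: Module.__init__() takes 1 positional argument but 2 were
given"; only transformers' PreTrainedModel.__init__ accepts a config, and that
class was dropped from the imports in the same patch. Below is a minimal
sketch of the corrected pattern; the embedding_dim value is an assumption,
since the hunk is cut off right after num_embeddings=1000.

    import torch.nn as nn
    from transformers import AutoModelForCausalLM

    class CustomModel(nn.Module):
        def __init__(self, model_name, config):
            # nn.Module.__init__ accepts no config; pass nothing here.
            super().__init__()
            # The wrapped causal LM still receives the config as usual.
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_name, config=config
            )
            # Lookup table for source ids (assumed sized to the hidden dim;
            # the original embedding_dim is not visible in the patch).
            self.source_embedding = nn.Embedding(
                num_embeddings=1000,
                embedding_dim=config.hidden_size,
            )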