l.gabrysiak 2025-02-25 17:57:49 +01:00
parent 8f8843fbb2
commit 359ca70172
1 changed file with 4 additions and 14 deletions

hft.py

@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
 from datasets import Dataset
 from PIL import Image
 import re
@@ -117,7 +117,7 @@ def custom_collate_fn(batch):
 class CustomModel(nn.Module):
     def __init__(self, model_name, config):
-        super().__init__(config)
+        super().__init__()
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,
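
The `super().__init__()` fix goes hand in hand with dropping the `PreTrainedModel` import in the first hunk: `torch.nn.Module.__init__` takes no arguments, so passing `config` up the chain is only valid when subclassing `PreTrainedModel`. A minimal sketch of the pattern the fix restores, assuming nothing about hft.py beyond what the hunk shows (the forward signature and the "gpt2" id are placeholders):

import torch.nn as nn
from transformers import AutoModelForCausalLM

class WrappedLM(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()  # nn.Module takes no config argument
        # The config is only needed to build the wrapped Hugging Face model
        self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Delegate to the wrapped causal LM; its output carries .loss and .logits
        return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
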
@@ -160,7 +160,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
 # Model initialization
 config = AutoModelForCausalLM.from_pretrained(model_name).config
-#print("Vocabulary size:", config.vocab_size)
+print("Vocabulary size:", config.vocab_size)
 model = CustomModel(model_name, config)
 model.to("cpu")  # Changed to CPU for debugging
@@ -204,14 +204,4 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     # Get the source from the last token
     last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
-    source = source_mapper.get_source(source_idx)
-    return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
-# Example usage
-question = "Ile dni urlopu przysługuje pracownikowi?"
-answer = generate_answer(question, model, tokenizer, source_mapper)
-print("Pytanie:", question)
-print("Odpowiedź:", answer)
+    source_idx = model.source_embeddi