mod
This commit is contained in:
parent
8f8843fbb2
commit
359ca70172
18
hft.py
18
hft.py
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, PreTrainedModel
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset
|
||||
from PIL import Image
|
||||
import re
|
||||
|
|
@ -117,7 +117,7 @@ def custom_collate_fn(batch):
|
|||
|
||||
class CustomModel(nn.Module):
|
||||
def __init__(self, model_name, config):
|
||||
super().__init__(config)
|
||||
super().__init__()
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
|
||||
self.source_embedding = nn.Embedding(
|
||||
num_embeddings=1000,
|
||||
|
|
@ -160,7 +160,7 @@ tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8)
|
|||
|
||||
# Inicjalizacja modelu
|
||||
config = AutoModelForCausalLM.from_pretrained(model_name).config
|
||||
#print("Vocabulary size:", config.vocab_size)
|
||||
print("Vocabulary size:", config.vocab_size)
|
||||
model = CustomModel(model_name, config)
|
||||
model.to("cpu") # Zmienione na CPU dla debugowania
|
||||
|
||||
|
|
@ -204,14 +204,4 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
|
|||
|
||||
# Pobierz źródło z ostatniego tokena
|
||||
last_token_id = outputs.sequences[0][-1].item()
|
||||
source_idx = model.source_embedding.weight.shape[0] - 1 # Tymczasowe rozwiązanie
|
||||
source = source_mapper.get_source(source_idx)
|
||||
|
||||
return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
|
||||
|
||||
# Przykład użycia
|
||||
question = "Ile dni urlopu przysługuje pracownikowi?"
|
||||
answer = generate_answer(question, model, tokenizer, source_mapper)
|
||||
print("Pytanie:", question)
|
||||
print("Odpowiedź:", answer)
|
||||
|
||||
source_idx = model.source_embeddi
|
||||
|
|
|
|||
Loading…
Reference in New Issue