diff --git a/hft.py b/hft.py
index 6fe00a8..52d6fa4 100644
--- a/hft.py
+++ b/hft.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, GenerationMixin
 from datasets import Dataset
 from PIL import Image
 import re
@@ -113,7 +113,7 @@ def custom_collate_fn(batch):
     source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long).cpu()
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
 
-class CustomModel(nn.Module):
+class CustomModel(nn.Module, GenerationMixin):
     def __init__(self, model_name, config):
         super().__init__()
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
@@ -122,6 +122,7 @@ class CustomModel(nn.Module):
             embedding_dim=config.hidden_size,
             padding_idx=-1
         )
+        self.config = config
         self.device = next(self.base_model.parameters()).device
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
@@ -135,6 +136,12 @@ class CustomModel(nn.Module):
 
         return outputs
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return self.base_model.prepare_inputs_for_generation(input_ids, **kwargs)
+
+    def _reorder_cache(self, past, beam_idx):
+        return self.base_model._reorder_cache(past, beam_idx)
+
 class CustomTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
         device = next(model.parameters()).device
@@ -190,7 +197,8 @@ trainer.train()
 
 # Function that generates an answer
 def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
-    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
+    device = next(model.parameters()).device
+    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512).to(device)
     outputs = model.generate(
         **inputs,