This commit is contained in:
l.gabrysiak 2025-02-25 19:28:44 +01:00
parent e068a27261
commit ed04739b58
1 changed files with 11 additions and 3 deletions

14
hft.py
View File

@ -1,7 +1,7 @@
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, GenerationMixin
from datasets import Dataset
from PIL import Image
import re
@ -113,7 +113,7 @@ def custom_collate_fn(batch):
source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long).cpu()
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}
class CustomModel(nn.Module):
class CustomModel(nn.Module, GenerationMixin):
def __init__(self, model_name, config):
super().__init__()
self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
@ -122,6 +122,7 @@ class CustomModel(nn.Module):
embedding_dim=config.hidden_size,
padding_idx=-1
)
self.config = config
self.device = next(self.base_model.parameters()).device
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
@ -135,6 +136,12 @@ class CustomModel(nn.Module):
return outputs
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Build generation inputs by delegating to the wrapped base model.

    Args:
        input_ids: Token ids forwarded unchanged to the base model.
        **kwargs: Any extra keyword arguments, forwarded unchanged.

    Returns:
        Whatever ``self.base_model.prepare_inputs_for_generation`` returns.
    """
    # This wrapper adds no logic of its own; the base model owns input prep.
    wrapped_model = self.base_model
    return wrapped_model.prepare_inputs_for_generation(input_ids, **kwargs)
def _reorder_cache(self, past, beam_idx):
    """Reorder cached state by delegating to the wrapped base model.

    Args:
        past: Cached state object, forwarded unchanged to the base model.
        beam_idx: Index object, forwarded unchanged to the base model.

    Returns:
        Whatever ``self.base_model._reorder_cache`` returns.
    """
    # NOTE(review): presumably invoked during beam search by the generation
    # machinery -- confirm against the installed transformers version.
    reorder = self.base_model._reorder_cache
    return reorder(past, beam_idx)
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
device = next(model.parameters()).device
@ -190,7 +197,8 @@ trainer.train()
# Function that generates an answer
def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
device = next(model.parameters()).device
inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512).to(device)
outputs = model.generate(
**inputs,