l.gabrysiak 2025-02-25 19:28:44 +01:00
parent e068a27261
commit ed04739b58
1 changed file with 11 additions and 3 deletions

hft.py

@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, GenerationMixin
 from datasets import Dataset
 from PIL import Image
 import re
@@ -113,7 +113,7 @@ def custom_collate_fn(batch):
     source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long).cpu()
     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "source_idx": source_idx}

-class CustomModel(nn.Module):
+class CustomModel(nn.Module, GenerationMixin):
     def __init__(self, model_name, config):
         super().__init__()
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
@@ -122,6 +122,7 @@ class CustomModel(nn.Module):
             embedding_dim=config.hidden_size,
             padding_idx=-1
         )
+        self.config = config
         self.device = next(self.base_model.parameters()).device

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
@@ -135,6 +136,12 @@ class CustomModel(nn.Module):
         return outputs

+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return self.base_model.prepare_inputs_for_generation(input_ids, **kwargs)
+
+    def _reorder_cache(self, past, beam_idx):
+        return self.base_model._reorder_cache(past, beam_idx)
+
 class CustomTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
         device = next(model.parameters()).device
@@ -190,7 +197,8 @@ trainer.train()
 # Function that generates an answer
 def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
-    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
+    device = next(model.parameters()).device
+    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512).to(device)
     outputs = model.generate(
         **inputs,
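
For reference, a minimal standalone sketch of the device-placement-then-generate pattern this commit introduces in generate_answer. This is not part of the commit: it uses a plain AutoModelForCausalLM instead of CustomModel, and "gpt2" is an assumed placeholder checkpoint.

# Sketch only: tokenize a prompt, move the tensors onto the model's device,
# then call generate() -- the same sequence the updated generate_answer performs.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"  # assumed placeholder checkpoint, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name)

# Resolve the device from the model parameters, as the commit does.
device = next(base_model.parameters()).device
inputs = tokenizer("Example question?", return_tensors="pt", truncation=True, max_length=512).to(device)

with torch.no_grad():
    output_ids = base_model.generate(**inputs, max_length=50)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))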