l.gabrysiak 2025-02-25 15:14:27 +01:00
parent 67f729ae68
commit 7179a2de95
1 changed file with 18 additions and 13 deletions

hft.py

@@ -9,7 +9,7 @@ import pytesseract
 import docx2txt
 import PyPDF2
 import json
-from torch.amp import autocast
+from torch.cuda.amp import autocast
 from collections import defaultdict
 from huggingface_hub import login
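A note on this import swap: torch.cuda.amp.autocast accepts a bare call, which matches the argument-less `with autocast():` used later in CustomModel.forward, whereas the generic torch.amp.autocast requires an explicit device type. A minimal sketch (requires a CUDA device; tensor sizes are illustrative):

import torch
from torch.cuda.amp import autocast  # callable with no arguments

# The generic equivalent would be: torch.amp.autocast(device_type="cuda")
a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")
with autocast():
    c = a @ b      # matmul runs in float16 under autocast
print(c.dtype)     # torch.float16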
@@ -21,21 +21,24 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 def free_memory():
-    torch.empty_cache('cuda')
-    torch.ipc_collect('cuda')
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()

 class SourceMapper:
     def __init__(self):
-        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
-        self.idx_to_source = {}
+        self.source_to_idx = defaultdict(lambda: 0)  # Default 0 for unknown sources
+        self.idx_to_source = {0: "Unknown"}
+        self.next_idx = 1  # Indices start at 1 for known sources

     def add_source(self, source):
         if source and source not in self.source_to_idx:
-            idx = self.source_to_idx[source]
+            idx = self.next_idx
+            self.source_to_idx[source] = idx
             self.idx_to_source[idx] = source
+            self.next_idx += 1

     def get_idx(self, source):
-        return self.source_to_idx[source] if source else -1
+        return self.source_to_idx.get(source, 0)

     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
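As a sanity check on the new indexing scheme, a minimal sketch using the SourceMapper defined above (the file names are made up):

mapper = SourceMapper()
mapper.add_source("handbook.pdf")          # assigned index 1
mapper.add_source("notes.docx")            # assigned index 2

assert mapper.get_idx("handbook.pdf") == 1
assert mapper.get_idx(None) == 0           # falsy source -> unknown
assert mapper.get_idx("missing.txt") == 0  # unmapped source -> unknown
assert mapper.get_source(0) == "Unknown"
assert mapper.get_source(2) == "notes.docx"

Using dict.get in get_idx also means an unmapped lookup no longer grows the defaultdict as the old __getitem__ path did.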
@@ -97,7 +100,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
     for chunk in chunks:
         data.append({
             "text": chunk,
-            "source_idx": -1
+            "source_idx": 0
         })
     return data
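The switch from -1 to 0 matters beyond consistency with SourceMapper: feeding -1 into nn.Embedding as an input index raises an IndexError, while 0 is a valid row and, after the change below, also the padding slot. An illustrative record (the chunk text is made up):

# Hypothetical record for a chunk with no catalogue entry. source_idx 0 is
# a real embedding row (the "Unknown"/padding slot); the old -1 would have
# raised "index out of range" when looked up in self.source_embedding.
record = {"text": "example chunk with no matched source", "source_idx": 0}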
@@ -126,8 +129,9 @@ class CustomModel(GPTNeoForCausalLM):
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,
             embedding_dim=config.hidden_size,
-            padding_idx=-1
+            padding_idx=0  # Fixed padding_idx
         )
+        self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)

     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         with autocast():
@@ -138,8 +142,9 @@ class CustomModel(GPTNeoForCausalLM):
                 **kwargs
             )
             if source_idx is not None:
-                source_embeds = self.source_embedding(source_idx).unsqueeze(1)
-                outputs.logits += source_embeds
+                source_embeds = self.source_embedding(source_idx)
+                source_projected = self.source_proj(source_embeds)
+                outputs.logits += source_projected.unsqueeze(1)
             return outputs

 source_mapper = SourceMapper()
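The old code added a (batch, 1, hidden_size) tensor to (batch, seq_len, vocab_size) logits, which cannot broadcast whenever hidden_size != vocab_size; the new source_proj maps the source embedding into vocabulary space first. A minimal shape check (dimensions are illustrative, not GPT-Neo's actual config):

import torch
import torch.nn as nn

batch, seq_len, hidden, vocab = 2, 16, 64, 100    # illustrative sizes
source_embedding = nn.Embedding(1000, hidden, padding_idx=0)
source_proj = nn.Linear(hidden, vocab)

logits = torch.randn(batch, seq_len, vocab)
source_idx = torch.tensor([3, 0])                 # 0 = unknown/padding

bias = source_proj(source_embedding(source_idx))  # (batch, vocab)
logits = logits + bias.unsqueeze(1)               # broadcasts over seq_len
print(logits.shape)                               # torch.Size([2, 16, 100])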
@@ -198,8 +203,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

     # Get the source from the last token
-    last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
+    last_token_logits = outputs.scores[-1]
+    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
     source = source_mapper.get_source(source_idx)

     return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"