From 7179a2de95eb6ab26ca525ba43ef473a4c1e7a49 Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 15:14:27 +0100
Subject: [PATCH] mod ds

---
 hft.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/hft.py b/hft.py
index 31bdac9..ecc1e42 100644
--- a/hft.py
+++ b/hft.py
@@ -9,7 +9,7 @@ import pytesseract
 import docx2txt
 import PyPDF2
 import json
-from torch.amp import autocast
+from torch.cuda.amp import autocast
 from collections import defaultdict
 from huggingface_hub import login
 
@@ -21,21 +21,24 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 def free_memory():
-    torch.empty_cache('cuda')
-    torch.ipc_collect('cuda')
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
 
 class SourceMapper:
     def __init__(self):
-        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
-        self.idx_to_source = {}
+        self.source_to_idx = defaultdict(lambda: 0)  # Default to 0 for unknown sources
+        self.idx_to_source = {0: "Unknown"}
+        self.next_idx = 1  # Indices start at 1 for known sources
 
     def add_source(self, source):
         if source and source not in self.source_to_idx:
-            idx = self.source_to_idx[source]
+            idx = self.next_idx
+            self.source_to_idx[source] = idx
             self.idx_to_source[idx] = source
+            self.next_idx += 1
 
     def get_idx(self, source):
-        return self.source_to_idx[source] if source else -1
+        return self.source_to_idx.get(source, 0)
 
     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
@@ -97,7 +100,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
         for chunk in chunks:
             data.append({
                 "text": chunk,
-                "source_idx": -1
+                "source_idx": 0
             })
     return data
 
@@ -126,8 +129,9 @@ class CustomModel(GPTNeoForCausalLM):
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,
             embedding_dim=config.hidden_size,
-            padding_idx=-1
+            padding_idx=0  # Corrected padding_idx
         )
+        self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         with autocast():
@@ -138,8 +142,9 @@
                 **kwargs
             )
             if source_idx is not None:
-                source_embeds = self.source_embedding(source_idx).unsqueeze(1)
-                outputs.logits += source_embeds
+                source_embeds = self.source_embedding(source_idx)
+                source_projected = self.source_proj(source_embeds)
+                outputs.logits += source_projected.unsqueeze(1)
         return outputs
 
 source_mapper = SourceMapper()
@@ -198,8 +203,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
 
     # Get the source from the last token
-    last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
+    last_token_logits = outputs.scores[-1]
+    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
     source = source_mapper.get_source(source_idx)
 
     return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
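Note on the autocast import: torch.cuda.amp.autocast() can be entered with no arguments, while the generic torch.amp.autocast requires an explicit device_type, so the bare `with autocast():` in forward only works with the torch.cuda.amp variant; that is presumably why the import was switched back. For reference, a minimal sketch of the equivalent call through the generic API:

    from torch.amp import autocast

    # equivalent to torch.cuda.amp.autocast() with no arguments
    with autocast(device_type="cuda"):
        pass  # mixed-precision region goes here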
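Note on the SourceMapper change: index 0 is now reserved for unknown sources and known sources are numbered from 1, which lines up with padding_idx=0 in the source embedding. A minimal usage sketch; the file names are hypothetical and the patched SourceMapper class from hft.py is assumed to be in scope:

    # assumes the patched SourceMapper class from hft.py is in scope
    mapper = SourceMapper()
    mapper.add_source("handbook.pdf")       # hypothetical source name
    print(mapper.get_idx("handbook.pdf"))   # 1 -> first known source
    print(mapper.get_idx("missing.txt"))    # 0 -> reserved unknown index
    print(mapper.get_source(0))             # "Unknown"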
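Note on the new source_proj layer: the old code added a (batch, 1, hidden) embedding directly to logits of shape (batch, seq, vocab), which only broadcasts when hidden equals vocab; the patch instead projects the (batch, hidden) source embedding to (batch, vocab) and unsqueezes it to (batch, 1, vocab) so it broadcasts across the sequence dimension. A standalone shape check, with sizes chosen purely for illustration:

    import torch
    import torch.nn as nn

    batch, seq, hidden, vocab = 2, 5, 64, 100         # illustrative sizes only
    source_embedding = nn.Embedding(1000, hidden, padding_idx=0)
    source_proj = nn.Linear(hidden, vocab)

    logits = torch.randn(batch, seq, vocab)           # stand-in for outputs.logits
    source_idx = torch.tensor([1, 0])                 # one known source, one unknown
    bias = source_proj(source_embedding(source_idx))  # (batch, vocab)
    logits = logits + bias.unsqueeze(1)               # broadcasts over seq
    assert logits.shape == (batch, seq, vocab)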
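Note on the generate_answer change: outputs.scores is only populated when model.generate is called with return_dict_in_generate=True and output_scores=True, and each entry is a (batch, vocab) logits tensor for one generation step, so outputs.scores[-1] holds the logits of the last generated token. A minimal sketch of the expected call, assuming model and tokenizer are a Hugging Face causal LM pair already in scope:

    import torch

    # model / tokenizer are assumed to be a loaded Hugging Face causal LM pair
    inputs = tokenizer("example question", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_length=200,
        return_dict_in_generate=True,  # makes generate return sequences/scores
        output_scores=True,            # populates outputs.scores
    )
    last_token_logits = outputs.scores[-1]  # (batch, vocab) logits, final step
    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()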