This commit is contained in:
l.gabrysiak 2025-02-25 15:14:27 +01:00
parent 67f729ae68
commit 7179a2de95
1 changed file with 18 additions and 13 deletions

hft.py

@@ -9,7 +9,7 @@ import pytesseract
 import docx2txt
 import PyPDF2
 import json
-from torch.amp import autocast
+from torch.cuda.amp import autocast
 from collections import defaultdict
 from huggingface_hub import login
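Note on the import swap: the bare autocast() call in forward() further down passes no arguments, which only the CUDA-specific context manager accepts; the generic torch.amp.autocast requires an explicit device_type. A minimal sketch of the two spellings, for comparison only:

    import torch

    with torch.cuda.amp.autocast():                 # accepts no arguments (CUDA-specific alias)
        pass

    with torch.amp.autocast(device_type="cuda"):    # generic API requires device_type
        pass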
@@ -21,21 +21,24 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 def free_memory():
-    torch.empty_cache('cuda')
-    torch.ipc_collect('cuda')
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
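This fix is needed because empty_cache and ipc_collect do not exist at the top level of the torch namespace; both live under torch.cuda. A defensive variant of the helper (the availability guard is an assumption, not part of this commit):

    def free_memory():
        if torch.cuda.is_available():   # assumed guard; skips the calls on CPU-only hosts
            torch.cuda.empty_cache()    # return cached allocator blocks to the driver
            torch.cuda.ipc_collect()    # reclaim memory held through CUDA IPC handles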
 class SourceMapper:
     def __init__(self):
-        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
-        self.idx_to_source = {}
+        self.source_to_idx = defaultdict(lambda: 0)  # Default to 0 for unknown sources
+        self.idx_to_source = {0: "Unknown"}
+        self.next_idx = 1  # Indices start at 1 for known sources
     def add_source(self, source):
         if source and source not in self.source_to_idx:
-            idx = self.source_to_idx[source]
+            idx = self.next_idx
             self.source_to_idx[source] = idx
             self.idx_to_source[idx] = source
+            self.next_idx += 1
     def get_idx(self, source):
-        return self.source_to_idx[source] if source else -1
+        return self.source_to_idx.get(source, 0)
     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
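The reworked SourceMapper reserves index 0 for unknown sources and hands out indices from 1 upward. It also fixes a lookup side effect: the old get_idx indexed the defaultdict directly, so every miss inserted a bogus entry through the len()-based factory, while the new .get() leaves the mapping untouched. A short usage sketch of the intended behaviour (file names are invented):

    mapper = SourceMapper()
    mapper.add_source("handbook.pdf")           # assigned index 1
    mapper.add_source("notes.docx")             # assigned index 2
    assert mapper.get_idx("handbook.pdf") == 1
    assert mapper.get_idx("missing.txt") == 0   # unknown source maps to 0, nothing is inserted
    assert mapper.get_source(0) == "Unknown"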
@@ -97,7 +100,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
         for chunk in chunks:
             data.append({
                 "text": chunk,
-                "source_idx": -1
+                "source_idx": 0
             })
     return data
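Unattributed chunks now carry source_idx 0 instead of -1, which lines up with idx_to_source[0] = "Unknown" and with padding_idx=0 in the embedding below. For contrast, an attributed chunk would presumably go through the mapper; source_name here is hypothetical:

    data.append({
        "text": chunk,
        "source_idx": source_mapper.get_idx(source_name),  # hypothetical attributed case
    })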
@@ -126,8 +129,9 @@ class CustomModel(GPTNeoForCausalLM):
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,
             embedding_dim=config.hidden_size,
-            padding_idx=-1
+            padding_idx=0  # Corrected padding_idx
         )
+        self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         with autocast():
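Note on padding_idx=0: nn.Embedding fills the row at padding_idx with zeros and excludes it from gradient updates, so the shared "Unknown" source contributes nothing and never trains. The old padding_idx=-1 was also misleading: the constructor interprets a negative padding_idx as counting from the end (here, row 999), while -1 as an input index fails the embedding lookup outright. A quick check:

    import torch.nn as nn

    emb = nn.Embedding(num_embeddings=1000, embedding_dim=768, padding_idx=0)  # 768 is an assumed hidden size
    print(emb.weight[0].abs().sum())   # tensor(0.) - the "Unknown" row is all zeros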
@@ -138,8 +142,9 @@ class CustomModel(GPTNeoForCausalLM):
                 **kwargs
             )
             if source_idx is not None:
-                source_embeds = self.source_embedding(source_idx).unsqueeze(1)
-                outputs.logits += source_embeds
+                source_embeds = self.source_embedding(source_idx)
+                source_projected = self.source_proj(source_embeds)
+                outputs.logits += source_projected.unsqueeze(1)
             return outputs
 source_mapper = SourceMapper()
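This rewrite fixes a shape mismatch: the old code added a hidden-size embedding directly to vocab-size logits, which only broadcasts when hidden_size equals vocab_size. The new source_proj maps the embedding into vocabulary space first. The tensor shapes through the new path, with B = batch size and T = sequence length:

    source_embeds    = self.source_embedding(source_idx)   # (B, hidden_size)
    source_projected = self.source_proj(source_embeds)     # (B, vocab_size)
    outputs.logits  += source_projected.unsqueeze(1)       # (B, 1, vocab_size), broadcast over (B, T, vocab_size)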
@@ -198,8 +203,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
     # Get the source from the last token
-    last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
+    last_token_logits = outputs.scores[-1]
+    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
     source = source_mapper.get_source(source_idx)
     return f"{answer}\n\nSource: {source if source else 'Own work'}"
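For outputs.scores to be populated, the earlier generate() call must request per-step scores; in Hugging Face transformers that means passing return_dict_in_generate=True and output_scores=True, roughly as below (the surrounding arguments are illustrative):

    outputs = model.generate(
        **inputs,                        # tokenized question; illustrative
        max_length=max_length,
        return_dict_in_generate=True,    # return a structured output instead of a bare tensor
        output_scores=True,              # populate outputs.scores with per-step logits
    )

Note that the argmax of the final step's logits is a vocabulary token id, not a SourceMapper index; ids outside the mapped range simply fall back to "Unknown" through idx_to_source.get.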