mod ds
This commit is contained in:
  parent 67f729ae68
  commit 7179a2de95

Changed file: hft.py (31 lines changed)
@@ -9,7 +9,7 @@ import pytesseract
 import docx2txt
 import PyPDF2
 import json
-from torch.amp import autocast
+from torch.cuda.amp import autocast
 from collections import defaultdict
 from huggingface_hub import login
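Why this import swap matters: the file later enters the context manager as a bare `with autocast():`, and `torch.amp.autocast` requires an explicit `device_type` argument, while the older CUDA-only `torch.cuda.amp.autocast` does not. A minimal sketch of the difference (exact deprecation status depends on the PyTorch version, which is an assumption here):

    import torch

    # torch.amp.autocast needs the device spelled out:
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        pass

    # torch.cuda.amp.autocast defaults to CUDA, so a bare autocast() call works
    # (it is the older, CUDA-only spelling of the same context manager):
    with torch.cuda.amp.autocast():
        pass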
@@ -21,21 +21,24 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 def free_memory():
-    torch.empty_cache('cuda')
-    torch.ipc_collect('cuda')
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()

 class SourceMapper:
     def __init__(self):
-        self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
-        self.idx_to_source = {}
+        self.source_to_idx = defaultdict(lambda: 0)  # Default to 0 for unknown sources
+        self.idx_to_source = {0: "Unknown"}
+        self.next_idx = 1  # Indices start at 1 for known sources

     def add_source(self, source):
         if source and source not in self.source_to_idx:
-            idx = self.source_to_idx[source]
+            idx = self.next_idx
             self.source_to_idx[source] = idx
             self.idx_to_source[idx] = source
+            self.next_idx += 1

     def get_idx(self, source):
-        return self.source_to_idx[source] if source else -1
+        return self.source_to_idx.get(source, 0)

     def get_source(self, idx):
         return self.idx_to_source.get(idx, "Unknown")
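This hunk fixes two distinct bugs: `torch.empty_cache` / `torch.ipc_collect` do not exist at the top level (they live under `torch.cuda` and take no device string), and the old `defaultdict(lambda: len(self.source_to_idx))` silently grew the mapping on every lookup of an unseen source. A quick usage sketch of the reworked mapper (the assertions are illustrative, not from the repo):

    mapper = SourceMapper()
    mapper.add_source("wikipedia")
    mapper.add_source("arxiv")

    assert mapper.get_idx("wikipedia") == 1     # known sources count up from 1
    assert mapper.get_idx("arxiv") == 2
    assert mapper.get_idx("unseen") == 0        # unknown sources collapse to 0
    assert mapper.get_source(0) == "Unknown"
    assert mapper.get_source(999) == "Unknown"  # out-of-range indices too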
@@ -97,7 +100,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
         for chunk in chunks:
             data.append({
                 "text": chunk,
-                "source_idx": -1
+                "source_idx": 0
             })
     return data
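This keeps uncataloged chunks consistent with the mapper: -1 was never a key in `idx_to_source`, while 0 now resolves cleanly to "Unknown". A record for a chunk with no catalog entry (the text here is made up) now looks like:

    {"text": "example chunk text", "source_idx": 0}  # 0 == SourceMapper's "Unknown"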
@@ -126,8 +129,9 @@ class CustomModel(GPTNeoForCausalLM):
         self.source_embedding = nn.Embedding(
             num_embeddings=1000,
             embedding_dim=config.hidden_size,
-            padding_idx=-1
+            padding_idx=0  # Corrected padding_idx
         )
+        self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)

     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         with autocast():
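A note on the `padding_idx` change (standard `nn.Embedding` behavior; the snippet is a sketch): in current PyTorch a negative `padding_idx` wraps around to `num_embeddings + padding_idx`, so -1 padded row 999 rather than the sentinel used elsewhere in the file. With `padding_idx=0`, the row for unknown sources is initialized to zeros and excluded from gradient updates, so Unknown contributes nothing and never trains:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(num_embeddings=1000, embedding_dim=4, padding_idx=0)
    print(emb(torch.tensor([0])))  # all zeros; this row also receives no gradient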
@@ -138,8 +142,9 @@ class CustomModel(GPTNeoForCausalLM):
             **kwargs
         )
         if source_idx is not None:
-            source_embeds = self.source_embedding(source_idx).unsqueeze(1)
-            outputs.logits += source_embeds
+            source_embeds = self.source_embedding(source_idx)
+            source_projected = self.source_proj(source_embeds)
+            outputs.logits += source_projected.unsqueeze(1)
         return outputs

 source_mapper = SourceMapper()
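A shape sketch of why the new `source_proj` is needed (dimensions below are illustrative; in the real model they come from `config`): the source embedding lives in hidden space, while `outputs.logits` is `(batch, seq_len, vocab_size)`, so the removed version only broadcast when `hidden_size` happened to equal `vocab_size`. Projecting to vocab size first makes the addition well-defined:

    import torch
    import torch.nn as nn

    batch, seq_len, hidden_size, vocab_size = 2, 16, 768, 50257
    source_embedding = nn.Embedding(1000, hidden_size, padding_idx=0)
    source_proj = nn.Linear(hidden_size, vocab_size)

    logits = torch.randn(batch, seq_len, vocab_size)   # stand-in for outputs.logits
    source_idx = torch.tensor([0, 3])                  # one source id per sample

    source_embeds = source_embedding(source_idx)       # (batch, hidden_size)
    source_projected = source_proj(source_embeds)      # (batch, vocab_size)
    logits = logits + source_projected.unsqueeze(1)    # broadcasts over seq_len
    assert logits.shape == (batch, seq_len, vocab_size)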
@@ -198,8 +203,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
     answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

     # Get the source from the last token
-    last_token_id = outputs.sequences[0][-1].item()
-    source_idx = model.source_embedding.weight.shape[0] - 1  # Temporary workaround
+    last_token_logits = outputs.scores[-1]
+    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
     source = source_mapper.get_source(source_idx)

     return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
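For context, `outputs.scores` is only populated when `generate()` is called with `return_dict_in_generate=True` and `output_scores=True` (which the use of `outputs.sequences` above suggests); it is a tuple holding one `(batch, vocab_size)` logits tensor per generated step. The new code therefore takes the argmax token id of the final step and reinterprets it as a source index; `SourceMapper.get_source()` falls back to "Unknown" whenever it has no entry for that id. A sketch with made-up tensors:

    import torch

    vocab_size = 50257
    last_token_logits = torch.randn(1, vocab_size)  # stands in for outputs.scores[-1]
    source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
    # source_idx is a *token* id in [0, vocab_size); most such ids are not
    # registered sources, so the lookup usually resolves to "Unknown".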