mod ds
This commit is contained in:
parent
67f729ae68
commit
7179a2de95
31
hft.py
31
hft.py
|
|
@ -9,7 +9,7 @@ import pytesseract
|
||||||
import docx2txt
|
import docx2txt
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
import json
|
import json
|
||||||
from torch.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
|
|
||||||
|
|
@ -21,21 +21,24 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
|
|
||||||
def free_memory():
|
def free_memory():
|
||||||
torch.empty_cache('cuda')
|
torch.cuda.empty_cache()
|
||||||
torch.ipc_collect('cuda')
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
class SourceMapper:
|
class SourceMapper:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.source_to_idx = defaultdict(lambda: len(self.source_to_idx))
|
self.source_to_idx = defaultdict(lambda: 0) # Domyślnie 0 dla nieznanych
|
||||||
self.idx_to_source = {}
|
self.idx_to_source = {0: "Unknown"}
|
||||||
|
self.next_idx = 1 # Indeksy od 1 dla znanych źródeł
|
||||||
|
|
||||||
def add_source(self, source):
|
def add_source(self, source):
|
||||||
if source and source not in self.source_to_idx:
|
if source and source not in self.source_to_idx:
|
||||||
idx = self.source_to_idx[source]
|
idx = self.next_idx
|
||||||
|
self.source_to_idx[source] = idx
|
||||||
self.idx_to_source[idx] = source
|
self.idx_to_source[idx] = source
|
||||||
|
self.next_idx += 1
|
||||||
|
|
||||||
def get_idx(self, source):
|
def get_idx(self, source):
|
||||||
return self.source_to_idx[source] if source else -1
|
return self.source_to_idx.get(source, 0)
|
||||||
|
|
||||||
def get_source(self, idx):
|
def get_source(self, idx):
|
||||||
return self.idx_to_source.get(idx, "Unknown")
|
return self.idx_to_source.get(idx, "Unknown")
|
||||||
|
|
@ -97,7 +100,7 @@ def prepare_dataset(directory, catalog_path, source_mapper):
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
data.append({
|
data.append({
|
||||||
"text": chunk,
|
"text": chunk,
|
||||||
"source_idx": -1
|
"source_idx": 0
|
||||||
})
|
})
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
@ -126,8 +129,9 @@ class CustomModel(GPTNeoForCausalLM):
|
||||||
self.source_embedding = nn.Embedding(
|
self.source_embedding = nn.Embedding(
|
||||||
num_embeddings=1000,
|
num_embeddings=1000,
|
||||||
embedding_dim=config.hidden_size,
|
embedding_dim=config.hidden_size,
|
||||||
padding_idx=-1
|
padding_idx=0 # Poprawiony padding_idx
|
||||||
)
|
)
|
||||||
|
self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
|
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
|
||||||
with autocast():
|
with autocast():
|
||||||
|
|
@ -138,8 +142,9 @@ class CustomModel(GPTNeoForCausalLM):
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
if source_idx is not None:
|
if source_idx is not None:
|
||||||
source_embeds = self.source_embedding(source_idx).unsqueeze(1)
|
source_embeds = self.source_embedding(source_idx)
|
||||||
outputs.logits += source_embeds
|
source_projected = self.source_proj(source_embeds)
|
||||||
|
outputs.logits += source_projected.unsqueeze(1)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
source_mapper = SourceMapper()
|
source_mapper = SourceMapper()
|
||||||
|
|
@ -198,8 +203,8 @@ def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
|
||||||
answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
|
answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
|
||||||
|
|
||||||
# Pobierz źródło z ostatniego tokena
|
# Pobierz źródło z ostatniego tokena
|
||||||
last_token_id = outputs.sequences[0][-1].item()
|
last_token_logits = outputs.scores[-1]
|
||||||
source_idx = model.source_embedding.weight.shape[0] - 1 # Tymczasowe rozwiązanie
|
source_idx = torch.argmax(last_token_logits, dim=-1)[-1].item()
|
||||||
source = source_mapper.get_source(source_idx)
|
source = source_mapper.get_source(source_idx)
|
||||||
|
|
||||||
return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
|
return f"{answer}\n\nŹródło: {source if source else 'Opracowanie własne'}"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue