modyfiakcja trenera

This commit is contained in:
l.gabrysiak 2025-02-25 12:27:54 +01:00
parent 136eddef07
commit 1c348d41c0
1 changed files with 5 additions and 3 deletions

8
hft.py
View File

@ -79,6 +79,7 @@ def prepare_dataset(directory, catalog_path):
# Tokenizacja danych
def tokenize_function(examples):
inputs = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
inputs["labels"] = inputs["input_ids"].copy()
inputs["source"] = examples["source"]
return inputs
@ -100,7 +101,8 @@ class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
source = inputs.pop("source")
outputs = model(**inputs, labels=labels, source=source)
source_ids = torch.tensor([hash(s) % 1000 for s in source], device=model.device)
outputs = model(**inputs, labels=labels, source=source_ids)
loss = outputs.loss
return (loss, outputs) if return_outputs else loss
@ -113,7 +115,7 @@ model = CustomModel.from_pretrained(model_name)
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path)
dataset = Dataset.from_list(data)
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
# Konfiguracja treningu
training_args = TrainingArguments(
@ -147,7 +149,7 @@ def generate_answer(question, model, tokenizer, dataset):
# Znajdź najbardziej prawdopodobne źródło
source_probs = outputs.scores[-1][:, model.source_embedding.weight.shape[0]:]
most_likely_source_idx = torch.argmax(source_probs).item()
most_likely_source = dataset[most_likely_source_idx]['source']
most_likely_source = dataset[most_likely_source_idx % len(dataset)]['source']
return f"{answer}\n\nŹródło: {most_likely_source}"