Trainer modification
parent 136eddef07
commit 1c348d41c0

hft.py | 8
@@ -79,6 +79,7 @@ def prepare_dataset(directory, catalog_path):
 # Data tokenization
 def tokenize_function(examples):
     inputs = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
+    inputs["labels"] = inputs["input_ids"].copy()
     inputs["source"] = examples["source"]
     return inputs
 
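Note on the hunk above: copying input_ids into labels is the standard setup for causal-LM fine-tuning (the model shifts the labels internally), but with padding="max_length" the pad positions are copied into the loss as well. Below is a minimal sketch of one common refinement that masks pad tokens to -100, the ignore index of the cross-entropy loss; the helper name mask_pad_labels is illustrative and not part of this commit.

def mask_pad_labels(inputs, pad_token_id):
    # Replace pad tokens in the copied labels with -100 so the loss
    # ignores them; all other positions keep the input_ids values.
    inputs["labels"] = [
        [tok if tok != pad_token_id else -100 for tok in seq]
        for seq in inputs["input_ids"]
    ]
    return inputs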
@@ -100,7 +101,8 @@ class CustomTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs.pop("labels")
         source = inputs.pop("source")
-        outputs = model(**inputs, labels=labels, source=source)
+        source_ids = torch.tensor([hash(s) % 1000 for s in source], device=model.device)
+        outputs = model(**inputs, labels=labels, source=source_ids)
         loss = outputs.loss
         return (loss, outputs) if return_outputs else loss
 
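Note on the hunk above: Python's built-in hash() is salted per process for strings (PYTHONHASHSEED), so the same source name can land in a different bucket on every run, which makes the learned source embedding hard to reuse after a restart. A deterministic alternative is sketched below as an assumption, not part of the commit; it keeps the same 1000-bucket range used in compute_loss.

import hashlib

def stable_source_id(source: str, num_buckets: int = 1000) -> int:
    # md5 of the source name yields the same integer in every process,
    # unlike the salted built-in hash().
    digest = hashlib.md5(source.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets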
@@ -113,7 +115,7 @@ model = CustomModel.from_pretrained(model_name)
 catalog_path = "file_catalog.json"
 data = prepare_dataset("files", catalog_path)
 dataset = Dataset.from_list(data)
-tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
 
 # Training configuration
 training_args = TrainingArguments(
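Note on the hunk above: Dataset.from_list() returns a single Dataset, not a DatasetDict, so there is no "train" split to index; column_names lives on the dataset itself, which is what the corrected call uses. A tiny reproduction (the toy rows are illustrative only):

from datasets import Dataset

toy = Dataset.from_list([{"text": "abc", "source": "a.txt"}])
print(toy.column_names)  # ['text', 'source']
# toy["train"] would raise a KeyError here: a plain Dataset has columns,
# not named splits.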
@@ -147,7 +149,7 @@ def generate_answer(question, model, tokenizer, dataset):
     # Find the most likely source
     source_probs = outputs.scores[-1][:, model.source_embedding.weight.shape[0]:]
     most_likely_source_idx = torch.argmax(source_probs).item()
-    most_likely_source = dataset[most_likely_source_idx]['source']
+    most_likely_source = dataset[most_likely_source_idx % len(dataset)]['source']
 
     return f"{answer}\n\nŹródło: {most_likely_source}"
 
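Note on the hunk above: the argmax over the source logits is taken over an id range that need not match the number of dataset rows, so the modulo keeps the lookup in bounds at the cost of wrapping onto an arbitrary row. A hedged sketch of the same guard as a standalone helper (safe_source_lookup is not part of the commit); clamping, or an explicit id-to-source table, would avoid the wrap-around.

def safe_source_lookup(dataset, idx):
    # Wrap out-of-range indices instead of raising IndexError; note that a
    # wrapped index no longer points at the row the model actually scored.
    return dataset[idx % len(dataset)]["source"]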