mod
This commit is contained in:
parent
d00a183104
commit
85ba5346fb
22
hft.py
22
hft.py
|
|
@ -15,7 +15,7 @@ from transformers import (
|
||||||
Trainer,
|
Trainer,
|
||||||
DataCollatorForLanguageModeling
|
DataCollatorForLanguageModeling
|
||||||
)
|
)
|
||||||
from datasets import Dataset, Features, Value, Sequence
|
from datasets import Dataset, Features, Value
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
|
|
||||||
# Konfiguracja
|
# Konfiguracja
|
||||||
|
|
@ -49,11 +49,9 @@ class LegalAITrainer:
|
||||||
self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
|
self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
|
||||||
self.confidence_layer = nn.Linear(config.hidden_size, 1)
|
self.confidence_layer = nn.Linear(config.hidden_size, 1)
|
||||||
|
|
||||||
# Freeze base model
|
|
||||||
for param in self.base_model.parameters():
|
for param in self.base_model.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
# Trainable components
|
|
||||||
for layer in [self.source_embedding, self.confidence_layer]:
|
for layer in [self.source_embedding, self.confidence_layer]:
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
@ -156,13 +154,13 @@ class LegalAITrainer:
|
||||||
|
|
||||||
features = Features({
|
features = Features({
|
||||||
"text": Value("string"),
|
"text": Value("string"),
|
||||||
"source_idx": Sequence(Value("int32")),
|
"source_idx": Value("int32"),
|
||||||
"is_legal": Value("int32")
|
"is_legal": Value("int32")
|
||||||
})
|
})
|
||||||
|
|
||||||
return Dataset.from_dict({
|
return Dataset.from_dict({
|
||||||
"text": [d["text"] for d in data],
|
"text": [d["text"] for d in data],
|
||||||
"source_idx": [[d["source_idx"]] for d in data], # Zwracamy jako listę list
|
"source_idx": [d["source_idx"] for d in data],
|
||||||
"is_legal": [d["is_legal"] for d in data]
|
"is_legal": [d["is_legal"] for d in data]
|
||||||
}, features=features), source_mapper
|
}, features=features), source_mapper
|
||||||
|
|
||||||
|
|
@ -179,13 +177,11 @@ class LegalAITrainer:
|
||||||
max_length=512,
|
max_length=512,
|
||||||
return_tensors="pt"
|
return_tensors="pt"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Konwersja tensorów do list i odpowiednich typów
|
|
||||||
return {
|
return {
|
||||||
"input_ids": [ids.tolist() for ids in tokenized["input_ids"]],
|
"input_ids": tokenized["input_ids"].squeeze(),
|
||||||
"attention_mask": [mask.tolist() for mask in tokenized["attention_mask"]],
|
"attention_mask": tokenized["attention_mask"].squeeze(),
|
||||||
"labels": [labels.tolist() for labels in tokenized["input_ids"]],
|
"labels": tokenized["input_ids"].squeeze().clone(),
|
||||||
"source_idx": [[idx] for idx in examples["source_idx"]] # Sekwencja długości 1
|
"source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
|
tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
|
||||||
|
|
@ -280,13 +276,11 @@ class LegalAITrainer:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
legal_ai = LegalAITrainer()
|
legal_ai = LegalAITrainer()
|
||||||
|
|
||||||
# Trening
|
|
||||||
legal_ai.train(
|
legal_ai.train(
|
||||||
model_name="crumb/nano-mistral",
|
model_name="crumb/nano-mistral",
|
||||||
data_dir="./legal_docs",
|
data_dir="./legal_docs",
|
||||||
catalog_path="./catalog.json"
|
catalog_path="./catalog.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test
|
test_prompt = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?"
|
||||||
test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?"
|
|
||||||
print(legal_ai.generate_response(test_prompt))
|
print(legal_ai.generate_response(test_prompt))
|
||||||
Loading…
Reference in New Issue