diff --git a/hft.py b/hft.py
index 13ee571..6e35ca2 100644
--- a/hft.py
+++ b/hft.py
@@ -15,7 +15,7 @@ from transformers import (
     Trainer,
     DataCollatorForLanguageModeling
 )
-from datasets import Dataset, Features, Value, Sequence
+from datasets import Dataset, Features, Value
 from huggingface_hub import login
 
 # Konfiguracja
@@ -49,11 +49,9 @@ class LegalAITrainer:
         self.source_embedding = nn.Embedding(100000, config.hidden_size, padding_idx=-1)
         self.confidence_layer = nn.Linear(config.hidden_size, 1)
 
-        # Freeze base model
         for param in self.base_model.parameters():
             param.requires_grad = False
 
-        # Trainable components
         for layer in [self.source_embedding, self.confidence_layer]:
             for param in layer.parameters():
                 param.requires_grad = True
@@ -156,13 +154,13 @@ class LegalAITrainer:
 
         features = Features({
             "text": Value("string"),
-            "source_idx": Sequence(Value("int32")),
+            "source_idx": Value("int32"),
             "is_legal": Value("int32")
         })
 
         return Dataset.from_dict({
             "text": [d["text"] for d in data],
-            "source_idx": [[d["source_idx"]] for d in data],  # Zwracamy jako listę list
+            "source_idx": [d["source_idx"] for d in data],
             "is_legal": [d["is_legal"] for d in data]
         }, features=features), source_mapper
 
@@ -179,13 +177,11 @@ class LegalAITrainer:
                 max_length=512,
                 return_tensors="pt"
             )
-
-            # Konwersja tensorów do list i odpowiednich typów
             return {
-                "input_ids": [ids.tolist() for ids in tokenized["input_ids"]],
-                "attention_mask": [mask.tolist() for mask in tokenized["attention_mask"]],
-                "labels": [labels.tolist() for labels in tokenized["input_ids"]],
-                "source_idx": [[idx] for idx in examples["source_idx"]]  # Sekwencja długości 1
+                "input_ids": tokenized["input_ids"].squeeze(),
+                "attention_mask": tokenized["attention_mask"].squeeze(),
+                "labels": tokenized["input_ids"].squeeze().clone(),
+                "source_idx": torch.tensor(examples["source_idx"], dtype=torch.int32)
             }
 
         tokenized_dataset = dataset.map(tokenize_fn, batched=True, batch_size=16)
@@ -280,13 +276,11 @@
 if __name__ == "__main__":
     legal_ai = LegalAITrainer()
 
-    # Trening
     legal_ai.train(
         model_name="crumb/nano-mistral",
         data_dir="./legal_docs",
         catalog_path="./catalog.json"
     )
 
-    # Test
-    test_prompt = "Jakie są obowiązki pracodawcy w zakresie BHP?"
+    test_prompt = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?"
     print(legal_ai.generate_response(test_prompt))
\ No newline at end of file
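
Note (not part of the patch): a minimal, hypothetical sketch of the schema change above, assuming the datasets library is installed. With the new Features definition, "source_idx" is stored as a plain int32 scalar per example instead of Sequence(Value("int32")), so rows are built from single ints rather than one-element lists; the example data below is invented for illustration only.

    from datasets import Dataset, Features, Value

    # New schema as in the patch: scalar int32 for source_idx
    features = Features({
        "text": Value("string"),
        "source_idx": Value("int32"),
        "is_legal": Value("int32"),
    })

    # Hypothetical single-row dataset; values are placeholders
    ds = Dataset.from_dict(
        {
            "text": ["example legal text"],
            "source_idx": [0],   # one int per row, previously [[0]] with Sequence
            "is_legal": [1],
        },
        features=features,
    )

    print(ds[0]["source_idx"])  # -> 0 (int), no longer a length-1 list

This matches how the patched tokenize_fn now feeds examples["source_idx"] directly into torch.tensor(..., dtype=torch.int32) without unwrapping nested lists.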