diff --git a/hft.py b/hft.py
index 20b6f78..0561452 100644
--- a/hft.py
+++ b/hft.py
@@ -1,3 +1,8 @@
+import nltk
+nltk.download('averaged_perceptron_tagger', quiet=True)
+nltk.download('wordnet', quiet=True)  # lang='pol' additionally needs the 'omw-1.4' corpus
+nltk.download('punkt', quiet=True)
+
 import os
 import torch
 import random
@@ -45,7 +50,7 @@ class SourceMapper:
 class LegalProcessor:
     def __init__(self, catalog_path):
         self.catalog = self.load_catalog(catalog_path)
-        self.augmenter = SynonymAug(aug_src='wordnet', aug_max=3)
+        self.augmenter = SynonymAug(aug_src='wordnet', aug_max=3, lang='pol')
 
     def load_catalog(self, path):
         try:
@@ -137,6 +142,34 @@ class LegalProcessor:
 
         return [f"[Custom] {chunk}" for chunk in chunks if chunk.strip()]
 
+class CustomModel(torch.nn.Module):
+    def __init__(self, model_name):
+        super().__init__()
+        self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
+        self.source_emb = torch.nn.Embedding(1000, self.base_model.config.hidden_size)
+
+        # Freeze the base model parameters
+        for param in self.base_model.parameters():
+            param.requires_grad = False
+
+        # Unfreeze the last decoder layers (Mistral-style models expose them as model.layers)
+        for layer in self.base_model.model.layers[-2:]:
+            for param in layer.parameters():
+                param.requires_grad = True
+
+        self.base_model.get_output_embeddings().requires_grad_(True)
+
+    def forward(self, input_ids, attention_mask, labels, source_idx):
+        inputs_embeds = self.base_model.get_input_embeddings()(input_ids)
+        source_emb = self.source_emb(source_idx.clamp(0, 999)).unsqueeze(1)
+        inputs_embeds = inputs_embeds + source_emb  # out-of-place add; avoids mutating the embedding output in place
+
+        return self.base_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            labels=labels
+        )
+
 def main():
     # Initialize components
     source_mapper = SourceMapper()
@@ -173,7 +206,7 @@ def main():
     # Multithreaded processing
     with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
         futures = []
-        for root, _, files in os.walk("files"):  # Changed to "files"
+        for root, _, files in os.walk("files"):
             for file in files:
                 file_path = os.path.join(root, file)
                 futures.append(executor.submit(process_and_augment, file_path))
@@ -181,7 +214,70 @@ def main():
         for future in futures:
             future.result()
 
-    # The rest of the code remains unchanged...
+    print(f"\nPrepared {len(data)} training examples")
+    print("Sample data:")
+    for example in random.sample(data, 3):
+        print(f"\nSource: {source_mapper.get_source(example['source_idx'])}")
+        print(f"Text: {example['text'][:150]}...")
+
+    # Prepare the dataset
+    dataset = Dataset.from_list(data)
+
+    def tokenize_fn(examples):
+        tokenized = tokenizer(
+            examples["text"],
+            max_length=512,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt"
+        )
+        return {
+            "input_ids": tokenized["input_ids"].squeeze(),
+            "attention_mask": tokenized["attention_mask"].squeeze(),
+            "labels": tokenized["input_ids"].squeeze(),
+            "source_idx": examples["source_idx"]
+        }
+
+    tokenized_ds = dataset.map(
+        tokenize_fn,
+        batched=True,
+        batch_size=32,
+        num_proc=4
+    )
+
+    # Initialize the model
+    model = CustomModel("crumb/nano-mistral")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    # Training configuration
+    training_args = TrainingArguments(
+        output_dir="./wyniki",
+        num_train_epochs=5,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=8,
+        learning_rate=2e-5,
+        fp16=torch.cuda.is_available(),
+        logging_steps=20,
+        save_strategy="epoch",
+        report_to="none"
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_ds,
+        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+    )
+
+    # Training
+    print("\nStarting training...")
+    trainer.train()
+
+    # Save the model (save_pretrained covers only the base model; source_emb must be saved separately, e.g. via torch.save)
+    model.base_model.save_pretrained("./trained_legal_model")
+    tokenizer.save_pretrained("./trained_legal_model")
+    print("Training completed successfully!")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file