diff --git a/allegro.py b/allegro.py index ba30a0e..dabe8f4 100644 --- a/allegro.py +++ b/allegro.py @@ -9,7 +9,7 @@ import numpy as np from sentence_transformers import SentenceTransformer from datasets import Dataset from peft import LoraConfig, get_peft_model -from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, MarianForCausalLM, MarianTokenizer embed_model = SentenceTransformer("all-MiniLM-L6-v2") @@ -64,8 +64,8 @@ eval_dataset = split_dataset["test"] device = "cuda" if torch.cuda.is_available() else "cpu" model_name = "allegro/multislav-5lang" -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) -tokenizer = AutoTokenizer.from_pretrained(model_name) +model = MarianForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) +tokenizer = MarianTokenizer.from_pretrained(model_name) lora_config = LoraConfig( r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM" diff --git a/requirements.txt b/requirements.txt index d08ae7e..4929b01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,6 @@ numpy peft weaviate-client sentence_transformers -faiss-gpu \ No newline at end of file +faiss-gpu +sentencepiece +sacremoses \ No newline at end of file