diff --git a/allegro.py b/allegro.py index a058b10..7a1582f 100644 --- a/allegro.py +++ b/allegro.py @@ -9,12 +9,14 @@ tokenizer = MarianTokenizer.from_pretrained(model_name) # Załaduj dane (przykład dla tłumaczenia z języka rumuńskiego na angielski) dataset = load_dataset("wmt16", "ro-en") -# Przetwórz dane do formatu odpowiedniego dla modelu def tokenize_function(examples): - # Jeśli 'translation' to lista słowników, np. [{'en': 'text1', 'ro': 'text1_translated'}, ...] - return tokenizer([example['en'] for example in examples['translation']], - [example['ro'] for example in examples['translation']], - truncation=True, padding='max_length', max_length=128) + # Tokenizacja + tokenized = tokenizer([example['en'] for example in examples['translation']], + [example['ro'] for example in examples['translation']], + truncation=True, padding='max_length', max_length=128) + # Ustawienie labels + tokenized['labels'] = tokenized['input_ids'].copy() + return tokenized tokenized_datasets = dataset.map(tokenize_function, batched=True)