From 46695e42cc5c57af11b5acef8b0b136bb5f82107 Mon Sep 17 00:00:00 2001
From: naem1023
Date: Mon, 2 Sep 2024 13:52:47 +0900
Subject: [PATCH] feat: add batch size of map function in the preprocessed
 dataset

Former-commit-id: 209313eeeab8d1a7c320bd9aa90a5f4656082b7c
---
 src/llamafactory/data/loader.py       | 3 +++
 src/llamafactory/hparams/data_args.py | 4 ++++
 2 files changed, 7 insertions(+)

diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 069ea199..dd9d9d2a 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -179,6 +179,9 @@ def _get_preprocessed_dataset(
             load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Running tokenizer on dataset",
         )
+        if data_args.dataset_map_batch_size:
+            # Only override the default batch size of the map function when a value is explicitly provided
+            kwargs.update(batch_size=data_args.dataset_map_batch_size)
 
     dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
 
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 0cb4a56d..d80c9165 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -109,6 +109,10 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
     )
+    dataset_map_batch_size: Optional[int] = field(
+        default=None,
+        metadata={"help": "Batch size for dataset mapping."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
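
Note (reviewer sketch, not part of the patch): the snippet below illustrates how a user-supplied dataset_map_batch_size would reach datasets.Dataset.map, which otherwise processes batches of 1000 examples when batched=True. The toy dataset, the fake_tokenize function, and the hard-coded value are illustrative stand-ins; only the kwargs.update pattern mirrors the change above.

# Standalone sketch of the new behaviour, assuming the Hugging Face `datasets` package is installed.
from datasets import Dataset

def fake_tokenize(batch):
    # Stand-in for the real preprocess_func; counts whitespace tokens per example.
    return {"n_tokens": [len(text.split()) for text in batch["text"]]}

# Toy data and value; in LLaMA-Factory the value would come from DataArguments.
dataset = Dataset.from_dict({"text": ["hello world"] * 10})
dataset_map_batch_size = 4

kwargs = dict(
    num_proc=1,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)
if dataset_map_batch_size:
    # Same pattern as the patch: only override map()'s default batch size when a value is set.
    kwargs.update(batch_size=dataset_map_batch_size)

dataset = dataset.map(fake_tokenize, batched=True, remove_columns=["text"], **kwargs)
print(dataset[0])  # {'n_tokens': 2}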