diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 069ea199..dd9d9d2a 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -179,6 +179,9 @@ def _get_preprocessed_dataset(
             load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Running tokenizer on dataset",
         )
+    if data_args.dataset_map_batch_size:
+        # Only override the default batch size of `Dataset.map` when a value is explicitly set.
+        kwargs.update(batch_size=data_args.dataset_map_batch_size)
 
     dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
 
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 0cb4a56d..d80c9165 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -109,6 +109,10 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
     )
+    dataset_map_batch_size: Optional[int] = field(
+        default=None,
+        metadata={"help": "Batch size used for the dataset `map` pre-processing."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
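
For context (not part of the patch): a minimal sketch of what the new `dataset_map_batch_size` option changes, assuming the keyword is forwarded to `datasets.Dataset.map`. The toy dataset, `preprocess_func`, and the stand-in variables below are hypothetical; in the patch the value comes from `data_args.dataset_map_batch_size` and `kwargs` is built inside `_get_preprocessed_dataset`. When `batch_size` is not passed, `datasets` falls back to its default of 1000 rows per call.

from datasets import Dataset

# Hypothetical batched pre-processing function: receives a dict of columns,
# each holding up to `batch_size` rows, and returns new columns of equal length.
def preprocess_func(examples):
    return {"text_len": [len(t) for t in examples["text"]]}

# Toy data standing in for the real training dataset.
dataset = Dataset.from_dict({"text": ["a", "bb", "ccc"]})

dataset_map_batch_size = 2                          # stands in for data_args.dataset_map_batch_size
kwargs = {"desc": "Running tokenizer on dataset"}   # stands in for the existing map kwargs

if dataset_map_batch_size:
    # Mirrors the patch: only override the library default when a value is set.
    kwargs.update(batch_size=dataset_map_batch_size)

dataset = dataset.map(preprocess_func, batched=True, **kwargs)
print(dataset[0])  # {'text': 'a', 'text_len': 1}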