diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index a686a0a6..be37f38f 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -179,6 +179,9 @@ def _get_preprocessed_dataset(
             load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Running tokenizer on dataset",
         )
+        if data_args.dataset_map_batch_size:
+            # Override the map function's default batch size only when one is explicitly provided.
+            kwargs.update(batch_size=data_args.dataset_map_batch_size)
 
     dataset = dataset.map(
         preprocess_func,
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 1adcf2d0..a03128c6 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -113,6 +113,10 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
     )
+    dataset_map_batch_size: Optional[int] = field(
+        default=None,
+        metadata={"help": "Batch size for dataset mapping."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
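
For context, the new dataset_map_batch_size option is forwarded as the batch_size keyword of datasets.Dataset.map. The sketch below shows that plumbing in isolation; the toy dataset, the preprocess_func body, and the printed values are illustrative stand-ins, not part of the patch.

from datasets import Dataset

# Toy dataset and preprocessing function standing in for LLaMA-Factory's
# preprocess_func; only the batch_size handling mirrors the patch above.
dataset = Dataset.from_dict({"text": ["hello world", "foo bar", "baz"]})

def preprocess_func(examples):
    # Batched map: examples["text"] is a list of up to batch_size strings.
    return {"n_chars": [len(t) for t in examples["text"]]}

kwargs = dict(desc="Running tokenizer on dataset")

dataset_map_batch_size = 2  # would come from DataArguments.dataset_map_batch_size
if dataset_map_batch_size:
    # Only override the datasets library default (batch_size=1000) when set.
    kwargs.update(batch_size=dataset_map_batch_size)

dataset = dataset.map(preprocess_func, batched=True, **kwargs)
print(dataset["n_chars"])  # [11, 7, 3]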