diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 645794c1..9de6477e 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -300,7 +300,7 @@ def get_dataset( raise ValueError("Turn off `streaming` when saving dataset to disk.") # Load and preprocess dataset - with training_args.main_process_first(desc="load dataset"): + with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)): dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) eval_dataset = _get_merged_dataset( data_args.eval_dataset, @@ -311,7 +311,7 @@ def get_dataset( return_dict=data_args.eval_on_each_dataset, ) - with training_args.main_process_first(desc="pre-process dataset"): + with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)): dataset = _get_preprocessed_dataset( dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False ) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index c84fb2f7..e6844733 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -133,6 +133,10 @@ class DataArguments: ) }, ) + data_shared_file_system: bool = field( + default=False, + metadata={"help": "Whether or not to use a shared file system for the datasets."}, + ) def __post_init__(self): def split_arg(arg):