diff --git a/README.md b/README.md index ab500587..30d1b9f9 100644 --- a/README.md +++ b/README.md @@ -204,7 +204,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage. -[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. +[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. Use `dataset_shards` to enable parallel preprocessing with streaming. [23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details. diff --git a/README_zh.md b/README_zh.md index 4862060e..d6643906 100644 --- a/README_zh.md +++ b/README_zh.md @@ -206,7 +206,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc [23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详细用法请参照 [examples](examples/README_zh.md)。 -[23/07/31] 我们支持了**数据流式加载**。请使用 `streaming: true` 和 `max_steps: 10000` 参数来流式加载数据集。 +[23/07/31] 我们支持了**数据流式加载**。请使用 `streaming: true` 和 `max_steps: 10000` 参数来流式加载数据集。 使用 `dataset_shards` 参数来开启流式加载下的并行预处理。 [23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。 diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index e495cc31..78fa1192 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -101,10 +101,12 @@ def _load_single_dataset( split=dataset_attr.split, cache_dir=cache_dir, token=model_args.ms_hub_token, - 
use_streaming=data_args.streaming, + use_streaming=data_args.streaming and not data_args.dataset_shards, # only set to True when the user requested streaming but does not want the dataset to be sharded ) if isinstance(dataset, MsDataset): dataset = dataset.to_hf_dataset() + if data_args.streaming and data_args.dataset_shards: + dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards) elif dataset_attr.load_from == "om_hub": check_version("openmind>=0.8.0", mandatory=True) @@ -131,10 +133,12 @@ def _load_single_dataset( split=dataset_attr.split, cache_dir=model_args.cache_dir, token=model_args.hf_hub_token, - streaming=data_args.streaming, num_proc=data_args.preprocessing_num_workers, trust_remote_code=model_args.trust_remote_code, + streaming=data_args.streaming and not data_args.dataset_shards, ) + if data_args.streaming and data_args.dataset_shards: + dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards) if dataset_attr.num_samples is not None and not data_args.streaming: target_num = dataset_attr.num_samples diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 3a66b2c0..3a51142b 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -83,6 +83,10 @@ class DataArguments: default=None, metadata={"help": "The number of processes to use for the pre-processing."}, ) + dataset_shards: Optional[int] = field( + default=None, + metadata={"help": "The number of shards to split the dataset into. Only used in streaming mode. This should be set to the same value as dataloader_num_workers. Not setting this while streaming data will cause the dataset to be non-sharded, and thus it can only be processed by a single worker."}, + ) max_samples: Optional[int] = field( default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},