From f869e44fe5924e45f81a38cd889ea8bc30a6a262 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 23 Dec 2023 14:42:20 +0800 Subject: [PATCH] fix #1909 Former-commit-id: 3e93c33af9f80e28c9f30af9b7ba20757358afb4 --- src/llmtuner/data/loader.py | 15 +++++++++------ src/llmtuner/hparams/data_args.py | 3 --- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 650e809b..f9019c8b 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -22,12 +22,15 @@ def get_dataset( max_samples = data_args.max_samples all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets - if data_args.cache_path is not None and os.path.exists(data_args.cache_path): - logger.warning("Loading dataset from disk will ignore other data arguments.") - dataset = load_from_disk(data_args.cache_path) - if data_args.streaming: - dataset = dataset.to_iterable_dataset() - return dataset + if data_args.cache_path is not None: + if os.path.exists(data_args.cache_path): + logger.warning("Loading dataset from disk will ignore other data arguments.") + dataset = load_from_disk(data_args.cache_path) + if data_args.streaming: + dataset = dataset.to_iterable_dataset() + return dataset + elif data_args.streaming: + raise ValueError("Turn off dataset streaming to save cache files.") for dataset_attr in data_args.dataset_list: logger.info("Loading dataset {}...".format(dataset_attr)) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index dd4f2bb9..7be4f4f5 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -127,9 +127,6 @@ class DataArguments: if self.streaming and self.max_samples is not None: raise ValueError("`max_samples` is incompatible with `streaming`.") - if self.streaming and self.cache_path: - raise ValueError("`cache_path` is incompatible with `streaming`.") - def init_for_training(self, seed: int): # support mixing multiple datasets self.seed = seed dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []