From 633624dc3c641c92b9fe913d97feafe4d7eb57f4 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 20 Dec 2023 16:11:07 +0800
Subject: [PATCH] fix #1909

Former-commit-id: c6abbbfe90dcb0e832f73f0c611fc32eaa7ea78d
---
 src/llmtuner/data/preprocess.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 2d2b2db6..ee1c0390 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -44,6 +44,13 @@ def preprocess_dataset(
 ) -> Union["Dataset", "IterableDataset"]:
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        dataset = load_from_disk(data_args.cache_path)
+        if data_args.streaming:
+            dataset = dataset.to_iterable_dataset()
+        return dataset
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
@@ -240,10 +247,6 @@ def preprocess_dataset(
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
 
-    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
-        logger.warning("Loading dataset from disk will ignore other data arguments.")
-        return load_from_disk(data_args.cache_path)
-
     with training_args.main_process_first(desc="dataset map pre-processing"):
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}