From f64be8ee8411afda76ad36b94b8e904c27047ada Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 20 Dec 2023 16:15:41 +0800
Subject: [PATCH] optimize data loading logic

Former-commit-id: ec1fe1daa98c61f62c753b22847de028b5c5cded
---
 src/llmtuner/data/loader.py     | 9 ++++++++-
 src/llmtuner/data/preprocess.py | 8 +-------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index bbdde411..6b8761de 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -1,7 +1,7 @@
 import os
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
-from datasets import concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
 
 from llmtuner.data.utils import checksum
 from llmtuner.extras.constants import FILEEXT2TYPE
@@ -22,6 +22,13 @@ def get_dataset(
     max_samples = data_args.max_samples
     all_datasets: List[Union["Dataset", "IterableDataset"]] = []  # support multiple datasets
 
+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        dataset = load_from_disk(data_args.cache_path)
+        if data_args.streaming:
+            dataset = dataset.to_iterable_dataset()
+        return dataset
+
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
 
diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index ee1c0390..b4cf524a 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -3,8 +3,6 @@ import tiktoken
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
 
-from datasets import load_from_disk
-
 from llmtuner.data.template import get_template_and_fix_tokenizer
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
@@ -45,11 +43,7 @@ def preprocess_dataset(
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
-        logger.warning("Loading dataset from disk will ignore other data arguments.")
-        dataset = load_from_disk(data_args.cache_path)
-        if data_args.streaming:
-            dataset = dataset.to_iterable_dataset()
-        return dataset
+        return dataset  # already preprocessed
 
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
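
Note: the cache branch moved into get_dataset() relies on the round trip between Dataset.save_to_disk() and load_from_disk() from the `datasets` library, converting to an IterableDataset when streaming mode is requested. Below is a minimal runnable sketch of that behavior; `cache_path`, `streaming`, and the toy data are hypothetical stand-ins for the corresponding DataArguments fields, not values taken from this patch.

    import os

    from datasets import Dataset, load_from_disk

    cache_path = "/tmp/llmtuner_cache"  # hypothetical cache location
    streaming = True                    # stands in for data_args.streaming

    # Emulate a prior preprocessing run that persisted its output.
    if not os.path.exists(cache_path):
        Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]}).save_to_disk(cache_path)

    # Mirrors the early-return branch added to get_dataset(): reload the
    # cached dataset and ignore all other data arguments.
    dataset = load_from_disk(cache_path)
    if streaming:
        # An on-disk Dataset is map-style; convert it for streaming mode.
        dataset = dataset.to_iterable_dataset()

    print(next(iter(dataset)))  # {'input_ids': [1, 2, 3]}

Moving this branch from preprocess_dataset() into get_dataset() lets a cached, already-preprocessed dataset short-circuit the loading loop entirely, so preprocess_dataset() can simply pass it through.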