This commit is contained in:
hiyouga
2023-12-23 14:42:20 +08:00
parent 0ad86a4f62
commit 0bbf7118df
2 changed files with 9 additions and 9 deletions

View File

@@ -22,12 +22,15 @@ def get_dataset(
max_samples = data_args.max_samples
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
elif data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))