From ba2be6371d1d3b3266f5650d6b055b5c8c64e176 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 19 Nov 2023 23:32:47 +0800 Subject: [PATCH] better data streaming Former-commit-id: 65ac8e84fd6f22255c587b20382fdf5d8131d015 --- src/llmtuner/data/loader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index b2a64075..8e9053ca 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -60,9 +60,12 @@ def get_dataset( split=data_args.split, cache_dir=model_args.cache_dir, token=model_args.hf_hub_token, - streaming=data_args.streaming + streaming=(data_args.streaming and (dataset_attr.load_from != "file")) ) + if data_args.streaming and (dataset_attr.load_from == "file"): + dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter + if max_samples is not None: # truncate dataset dataset = dataset.select(range(min(len(dataset), max_samples)))