diff --git a/assets/wechat.jpg b/assets/wechat.jpg index 1cdb261f..c8662c40 100644 Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index b2a64075..8e9053ca 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -60,9 +60,12 @@ def get_dataset( split=data_args.split, cache_dir=model_args.cache_dir, token=model_args.hf_hub_token, - streaming=data_args.streaming + streaming=(data_args.streaming and (dataset_attr.load_from != "file")) ) + if data_args.streaming and (dataset_attr.load_from == "file"): + dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter + if max_samples is not None: # truncate dataset dataset = dataset.select(range(min(len(dataset), max_samples)))