better data streaming

Former-commit-id: 00baaa990e099d6b75436eaa7a922a07646afa26
This commit is contained in:
hiyouga 2023-11-19 23:32:47 +08:00
parent d1e03512f4
commit 32545bd6d9
2 changed files with 4 additions and 1 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 141 KiB

After

Width:  |  Height:  |  Size: 140 KiB

View File

@ -60,9 +60,12 @@ def get_dataset(
split=data_args.split, split=data_args.split,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
streaming=data_args.streaming streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
) )
if data_args.streaming and (dataset_attr.load_from == "file"):
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if max_samples is not None: # truncate dataset if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples))) dataset = dataset.select(range(min(len(dataset), max_samples)))