better data streaming

Former-commit-id: 65ac8e84fd6f22255c587b20382fdf5d8131d015
This commit is contained in:
hiyouga 2023-11-19 23:32:47 +08:00
parent d2ff09a404
commit ba2be6371d

View File

@@ -60,9 +60,12 @@ def get_dataset(
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
)
if data_args.streaming and (dataset_attr.load_from == "file"):
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples)))