mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-05 18:32:14 +08:00
better data streaming
Former-commit-id: 65ac8e84fd6f22255c587b20382fdf5d8131d015
This commit is contained in:
parent
6f64aeeba2
commit
8e50cc3c5b
@ -60,9 +60,12 @@ def get_dataset(
|
|||||||
split=data_args.split,
|
split=data_args.split,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
streaming=data_args.streaming
|
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if data_args.streaming and (dataset_attr.load_from == "file"):
|
||||||
|
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||||
|
|
||||||
if max_samples is not None: # truncate dataset
|
if max_samples is not None: # truncate dataset
|
||||||
dataset = dataset.select(range(min(len(dataset), max_samples)))
|
dataset = dataset.select(range(min(len(dataset), max_samples)))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user