[model] fix kv cache (#7564)

This commit is contained in:
hoshi-hiyouga
2025-04-01 23:07:46 +08:00
committed by GitHub
parent a13b1bb49a
commit 2bfcad2394
16 changed files with 122 additions and 64 deletions

View File

@@ -101,12 +101,10 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=data_args.streaming and not data_args.dataset_shards, # only set to True when user specified streaming but do not want dataset to be sharded
use_streaming=data_args.streaming,
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
if data_args.streaming and data_args.dataset_shards:
dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
elif dataset_attr.load_from == "om_hub":
check_version("openmind>=0.8.0", mandatory=True)
@@ -135,10 +133,10 @@ def _load_single_dataset(
token=model_args.hf_hub_token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
streaming=data_args.streaming and not data_args.dataset_shards,
streaming=data_args.streaming and dataset_attr.load_from != "file",
)
if data_args.streaming and data_args.dataset_shards:
dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
if data_args.streaming and dataset_attr.load_from == "file":
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples