diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 322eefa0..fa5b12c5 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -108,7 +108,13 @@ def load_single_dataset( dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter if dataset_attr.num_samples is not None and not data_args.streaming: - indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples] + target_num = dataset_attr.num_samples + indexes = np.random.permutation(len(dataset))[:target_num] + target_num -= len(indexes) + if target_num > 0: + expand_indexes = np.random.choice(len(dataset), target_num) + indexes = np.concatenate((indexes, expand_indexes), axis=0) + dataset = dataset.select(indexes) logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))