diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index f5929f15..4d0503c3 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -120,8 +120,8 @@ def load_single_dataset( logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) if data_args.max_samples is not None: # truncate dataset - indexes = np.random.permutation(len(dataset))[: data_args.max_samples] - dataset = dataset.select(indexes) + max_samples = min(data_args.max_samples, len(dataset)) + dataset = dataset.select(range(max_samples)) return align_dataset(dataset, dataset_attr, data_args)