diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 5ce4392e..322eefa0 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -1,10 +1,9 @@ import inspect import os -import numpy as np -from numpy.random import RandomState import sys from typing import TYPE_CHECKING, Literal, Optional, Union +import numpy as np from datasets import load_dataset, load_from_disk from ..extras.constants import FILEEXT2TYPE @@ -108,20 +107,14 @@ def load_single_dataset( if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter - if data_args.max_samples is not None: # truncate dataset - num_samples = min(data_args.max_samples, len(dataset)) - dataset = dataset.select(range(num_samples)) + if dataset_attr.num_samples is not None and not data_args.streaming: + indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples] + dataset = dataset.select(indexes) + logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) - if dataset_attr.sample_num: - dataset_sample_num = dataset_attr.sample_num - logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本") - random_state = RandomState(42) - idx = random_state.permutation(len(dataset))[:dataset_sample_num] - dataset_sample_num -= len(idx) - if dataset_sample_num > 0: - idx2 = random_state.choice(len(dataset), dataset_sample_num) - idx = np.concatenate([idx, idx2], axis=0) - dataset = dataset.select(idx) + if data_args.max_samples is not None: # truncate dataset + indexes = np.random.permutation(len(dataset))[: data_args.max_samples] + dataset = dataset.select(indexes) return align_dataset(dataset, dataset_attr, data_args)