Update loader.py

Former-commit-id: b55fb611c57be03fb38218c7da1d96f6848496ba
hoshi-hiyouga 2024-05-30 00:12:12 +08:00 committed by GitHub
parent 69a51cacb1
commit c0f11a280e


@@ -1,10 +1,9 @@
 import inspect
 import os
-import numpy as np
-from numpy.random import RandomState
 import sys
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
+import numpy as np
 from datasets import load_dataset, load_from_disk
 
 from ..extras.constants import FILEEXT2TYPE
@@ -108,20 +107,14 @@ def load_single_dataset(
     if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
         dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
 
-    if data_args.max_samples is not None:  # truncate dataset
-        num_samples = min(data_args.max_samples, len(dataset))
-        dataset = dataset.select(range(num_samples))
+    if dataset_attr.num_samples is not None and not data_args.streaming:
+        indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples]
+        dataset = dataset.select(indexes)
+        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
 
-    if dataset_attr.sample_num:
-        dataset_sample_num = dataset_attr.sample_num
-        logger.info(f"{dataset_attr.dataset_name}: sampled {dataset_sample_num} training examples")
-        random_state = RandomState(42)
-        idx = random_state.permutation(len(dataset))[:dataset_sample_num]
-        dataset_sample_num -= len(idx)
-        if dataset_sample_num > 0:
-            idx2 = random_state.choice(len(dataset), dataset_sample_num)
-            idx = np.concatenate([idx, idx2], axis=0)
-        dataset = dataset.select(idx)
+    if data_args.max_samples is not None:  # truncate dataset
+        indexes = np.random.permutation(len(dataset))[: data_args.max_samples]
+        dataset = dataset.select(indexes)
 
     return align_dataset(dataset, dataset_attr, data_args)
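
For reference, below is a minimal, self-contained sketch of the sampling behavior this hunk introduces, using a plain Python list in place of a datasets.Dataset. The names toy_dataset, num_samples, max_samples, and requested are illustrative stand-ins, not identifiers from loader.py; the last part mirrors the removed sample_num path, which padded with replacement when the requested count exceeded the dataset size (the new num_samples path does not do this).

import numpy as np

toy_dataset = list(range(100))  # stands in for a loaded Dataset

# dataset_attr.num_samples: draw a fixed number of rows at random,
# without replacement (permute the indices, then keep the first N).
num_samples = 10
indexes = np.random.permutation(len(toy_dataset))[:num_samples]
sampled = [toy_dataset[i] for i in indexes]
assert len(sampled) == num_samples

# data_args.max_samples: truncate the dataset, here also by keeping the
# first max_samples entries of a random permutation of the indices.
max_samples = 25
indexes = np.random.permutation(len(toy_dataset))[:max_samples]
truncated = [toy_dataset[i] for i in indexes]
assert len(truncated) == max_samples

# Removed sample_num behavior (roughly): if more rows were requested than
# exist, pad the permuted indices with replacement via np.random.choice.
requested = 120
idx = np.random.permutation(len(toy_dataset))[:requested]  # at most 100 here
missing = requested - len(idx)
if missing > 0:
    idx = np.concatenate([idx, np.random.choice(len(toy_dataset), missing)], axis=0)
assert len(idx) == requested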