Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Commit c0f11a280e: Update loader.py
Parent: 69a51cacb1
Former-commit-id: b55fb611c57be03fb38218c7da1d96f6848496ba
@@ -1,10 +1,9 @@
 import inspect
 import os
-import numpy as np
-from numpy.random import RandomState
 import sys
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
+import numpy as np
 from datasets import load_dataset, load_from_disk
 
 from ..extras.constants import FILEEXT2TYPE
@@ -108,20 +107,14 @@ def load_single_dataset(
     if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
         dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
 
-    if data_args.max_samples is not None:  # truncate dataset
-        num_samples = min(data_args.max_samples, len(dataset))
-        dataset = dataset.select(range(num_samples))
+    if dataset_attr.num_samples is not None and not data_args.streaming:
+        indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples]
+        dataset = dataset.select(indexes)
+        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
 
-    if dataset_attr.sample_num:
-        dataset_sample_num = dataset_attr.sample_num
-        logger.info(f"Sampling {dataset_sample_num} training examples from {dataset_attr.dataset_name}")
-        random_state = RandomState(42)
-        idx = random_state.permutation(len(dataset))[:dataset_sample_num]
-        dataset_sample_num -= len(idx)
-        if dataset_sample_num > 0:
-            idx2 = random_state.choice(len(dataset), dataset_sample_num)
-            idx = np.concatenate([idx, idx2], axis=0)
-        dataset = dataset.select(idx)
+    if data_args.max_samples is not None:  # truncate dataset
+        indexes = np.random.permutation(len(dataset))[: data_args.max_samples]
+        dataset = dataset.select(indexes)
 
     return align_dataset(dataset, dataset_attr, data_args)
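
For context, a minimal self-contained sketch (illustrative only, not part of the commit) contrasting the two sampling behaviours touched by this hunk: the num_samples path kept here draws at most len(dataset) unique indexes from a random permutation, while the removed sample_num path padded the selection with numpy.random.choice (sampling with replacement) whenever the requested count exceeded the dataset size. The helper names below are hypothetical.

import numpy as np
from numpy.random import RandomState


def sample_capped(num_rows: int, num_samples: int) -> np.ndarray:
    # Kept behaviour: a random subset, silently capped at num_rows.
    return np.random.permutation(num_rows)[:num_samples]


def sample_padded(num_rows: int, sample_num: int, seed: int = 42) -> np.ndarray:
    # Removed behaviour: pad with replacement when sample_num > num_rows.
    random_state = RandomState(seed)
    idx = random_state.permutation(num_rows)[:sample_num]
    remaining = sample_num - len(idx)
    if remaining > 0:
        idx = np.concatenate([idx, random_state.choice(num_rows, remaining)], axis=0)
    return idx


print(len(sample_capped(100, 500)))  # 100: capped at the dataset size
print(len(sample_padded(100, 500)))  # 500: duplicates fill the shortfall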