Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 23:58:11 +08:00)
Merge pull request #3829 from seanzhang-zhichen/add_dataset_sample_num

Add dataset sample num

Former-commit-id: ab38cf74ce48ea4f1800e077ca287f2eb9336135
commit 4f7c850115
data/README.md:

@@ -12,6 +12,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
   "ranking": "whether the dataset is a preference dataset or not. (default: False)",
   "subset": "the name of the subset. (optional, default: None)",
   "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
+  "num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
   "columns (optional)": {
     "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
     "query": "the column name in the dataset containing the queries. (default: input)",
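For illustration, a dataset_info.json entry using the new key could look like the following, written here as a Python dict so it doubles as runnable code; the dataset name and file name are placeholders:

    # Hypothetical dataset_info.json entry showing the new "num_samples" key.
    dataset_info = {
        "my_dataset": {  # placeholder dataset name
            "file_name": "my_dataset.json",  # placeholder local file
            "formatting": "alpaca",
            "ranking": False,
            "num_samples": 1000,  # randomly draw 1000 training examples
        }
    }

With num_samples set, the loader draws exactly that many examples at random, sampling with replacement if the dataset is smaller than the request; max_samples, by contrast, only caps the size.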
data/README_zh.md:

@@ -12,6 +12,7 @@
   "ranking": "是否为偏好数据集(可选,默认:False)",
   "subset": "数据集子集的名称(可选,默认:None)",
   "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
+  "num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
   "columns(可选)": {
     "prompt": "数据集代表提示词的表头名称(默认:instruction)",
     "query": "数据集代表请求的表头名称(默认:input)",
data/loader.py:

@@ -3,6 +3,7 @@ import os
 import sys
 from typing import TYPE_CHECKING, Literal, Optional, Union

+import numpy as np
 from datasets import load_dataset, load_from_disk

 from ..extras.constants import FILEEXT2TYPE
data/loader.py:

@@ -106,9 +107,21 @@ def load_single_dataset(
     if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
         dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter

+    if dataset_attr.num_samples is not None and not data_args.streaming:
+        target_num = dataset_attr.num_samples
+        indexes = np.random.permutation(len(dataset))[:target_num]
+        target_num -= len(indexes)
+        if target_num > 0:
+            expand_indexes = np.random.choice(len(dataset), target_num)
+            indexes = np.concatenate((indexes, expand_indexes), axis=0)
+
+        assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
+        dataset = dataset.select(indexes)
+        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+
     if data_args.max_samples is not None:  # truncate dataset
-        num_samples = min(data_args.max_samples, len(dataset))
-        dataset = dataset.select(range(num_samples))
+        indexes = np.random.permutation(len(dataset))[: data_args.max_samples]
+        dataset = dataset.select(indexes)

     return align_dataset(dataset, dataset_attr, data_args)
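The new block permutes the dataset indices and, when num_samples exceeds the dataset size, tops the selection up with indices drawn with replacement via np.random.choice, so the requested count is always met (duplicate examples then appear). Note the behavior change for max_samples as well: truncation now selects a random subset of rows instead of the first N. A self-contained sketch of the same index-selection strategy, with made-up sizes:

    import numpy as np

    def sample_indexes(dataset_size: int, num_samples: int) -> "np.ndarray":
        # Random permutation first: every example is used at most once
        # as long as num_samples <= dataset_size.
        indexes = np.random.permutation(dataset_size)[:num_samples]
        remaining = num_samples - len(indexes)
        if remaining > 0:
            # The request exceeds the dataset: draw extras with replacement.
            extra = np.random.choice(dataset_size, remaining)
            indexes = np.concatenate((indexes, extra), axis=0)
        assert len(indexes) == num_samples, "Sample num mismatched."
        return indexes

    print(sample_indexes(6, 10))  # 10 indices from a 6-example dataset: some repeat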
data/parser.py:

@@ -20,11 +20,12 @@ class DatasetAttr:
     """ basic configs """
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: str
+    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
+    ranking: bool = False
     """ extra configs """
     subset: Optional[str] = None
     folder: Optional[str] = None
-    ranking: bool = False
-    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
+    num_samples: Optional[int] = None
     """ common columns """
     system: Optional[str] = None
     tools: Optional[str] = None
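Because num_samples defaults to None, existing code that constructs DatasetAttr without it keeps working; the hunk otherwise only regroups formatting and ranking under the basic configs. A trimmed, standalone sketch of the reordered fields (only the fields shown in this hunk, with the bare-string section markers omitted):

    from dataclasses import dataclass
    from typing import Literal, Optional

    @dataclass
    class DatasetAttr:
        load_from: Literal["hf_hub", "ms_hub", "script", "file"]
        dataset_name: str
        formatting: Literal["alpaca", "sharegpt"] = "alpaca"
        ranking: bool = False
        subset: Optional[str] = None
        folder: Optional[str] = None
        num_samples: Optional[int] = None

    # Placeholder file name; only num_samples needs to be passed explicitly.
    attr = DatasetAttr("file", dataset_name="my_dataset.json", num_samples=1000)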
data/parser.py:

@@ -102,10 +103,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
         else:
             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

+        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
+        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
         dataset_attr.set_attr("subset", dataset_info[name])
         dataset_attr.set_attr("folder", dataset_info[name])
-        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
-        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
+        dataset_attr.set_attr("num_samples", dataset_info[name])

         if "columns" in dataset_info[name]:
             column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
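set_attr itself is outside this diff; it presumably copies a key from the dataset_info entry onto the dataclass, falling back to a default when the key is absent, which is why "num_samples" needs no explicit default (None falls through). A minimal sketch under that assumption:

    from typing import Any, Dict, Optional

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        # Use the value from the dataset_info entry when present, else the default.
        setattr(self, key, obj.get(key, default))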