diff --git a/data/README.md b/data/README.md
index 3e96fbeb..5ceae666 100644
--- a/data/README.md
+++ b/data/README.md
@@ -12,6 +12,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
   "ranking": "whether the dataset is a preference dataset or not. (default: False)",
   "subset": "the name of the subset. (optional, default: None)",
   "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
+  "num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
   "columns (optional)": {
     "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
     "query": "the column name in the dataset containing the queries. (default: input)",
diff --git a/data/README_zh.md b/data/README_zh.md
index aff6fdb1..1795f352 100644
--- a/data/README_zh.md
+++ b/data/README_zh.md
@@ -12,6 +12,7 @@
   "ranking": "是否为偏好数据集(可选,默认:False)",
   "subset": "数据集子集的名称(可选,默认:None)",
   "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
+  "num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
   "columns(可选)": {
     "prompt": "数据集代表提示词的表头名称(默认:instruction)",
     "query": "数据集代表请求的表头名称(默认:input)",
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 48d28f1d..d4a19e27 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -3,6 +3,7 @@ import os
 import sys
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
+import numpy as np
 from datasets import load_dataset, load_from_disk
 
 from ..extras.constants import FILEEXT2TYPE
@@ -106,9 +107,21 @@ def load_single_dataset(
     if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
         dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
 
+    if dataset_attr.num_samples is not None and not data_args.streaming:
+        target_num = dataset_attr.num_samples
+        indexes = np.random.permutation(len(dataset))[:target_num]
+        target_num -= len(indexes)
+        if target_num > 0:
+            expand_indexes = np.random.choice(len(dataset), target_num)
+            indexes = np.concatenate((indexes, expand_indexes), axis=0)
+
+        assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
+        dataset = dataset.select(indexes)
+        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+
     if data_args.max_samples is not None:  # truncate dataset
-        num_samples = min(data_args.max_samples, len(dataset))
-        dataset = dataset.select(range(num_samples))
+        indexes = np.random.permutation(len(dataset))[: data_args.max_samples]
+        dataset = dataset.select(indexes)
 
     return align_dataset(dataset, dataset_attr, data_args)
 
diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py
index 679f8ad6..ec97bfc1 100644
--- a/src/llamafactory/data/parser.py
+++ b/src/llamafactory/data/parser.py
@@ -20,11 +20,12 @@ class DatasetAttr:
     """ basic configs """
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: str
+    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
+    ranking: bool = False
     """ extra configs """
     subset: Optional[str] = None
     folder: Optional[str] = None
-    ranking: bool = False
-    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
+    num_samples: Optional[int] = None
     """ common columns """
     system: Optional[str] = None
     tools: Optional[str] = None
@@ -102,10 +103,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
         else:
             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
 
+        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
+        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
         dataset_attr.set_attr("subset", dataset_info[name])
         dataset_attr.set_attr("folder", dataset_info[name])
-        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
-        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
+        dataset_attr.set_attr("num_samples", dataset_info[name])
 
         if "columns" in dataset_info[name]:
             column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
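Usage note: after this patch, per-dataset sampling is configured through the new "num_samples" key in dataset_info.json, next to the existing "formatting" and "ranking" keys. A minimal sketch of such an entry, following the schema documented in the README hunks above (the dataset name and file name are placeholders, not part of this patch):

    {
      "my_dataset": {
        "file_name": "my_dataset.json",
        "formatting": "alpaca",
        "num_samples": 1000
      }
    }

If "num_samples" exceeds the dataset size, the loader fills the gap by drawing extra indexes with replacement rather than raising an error; streaming datasets ignore the key, since the branch in load_single_dataset is guarded by "not data_args.streaming".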
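For reference, a minimal standalone sketch of the sampling strategy that load_single_dataset now applies (illustrative only: it uses numpy's newer Generator API, whereas the patch calls the legacy np.random functions):

    import numpy as np

    def sample_indexes(dataset_len: int, num_samples: int, seed: int = 42) -> np.ndarray:
        # Take a random permutation first, so up to dataset_len picks are unique.
        rng = np.random.default_rng(seed)
        indexes = rng.permutation(dataset_len)[:num_samples]
        remaining = num_samples - len(indexes)
        if remaining > 0:
            # num_samples exceeds the dataset size: oversample with replacement,
            # mirroring the np.random.choice branch in the patch.
            extra = rng.choice(dataset_len, remaining)
            indexes = np.concatenate((indexes, extra), axis=0)
        assert len(indexes) == num_samples, "Sample num mismatched."
        return indexes

    print(sample_indexes(5, 8))  # 8 indexes over 5 rows, so at least 3 repeats

The same permutation trick also replaces the old max_samples truncation, which previously always kept the first max_samples rows instead of a random subset.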