From 7cdc16abdf63a01e9769d423346c5e952fcad263 Mon Sep 17 00:00:00 2001 From: zhangzc <2608882093@qq.com> Date: Wed, 27 Mar 2024 14:22:50 +0800 Subject: [PATCH 1/7] Supports custom data set sampling quantity Former-commit-id: fa8325401df27595de4611a89dfcc14644956abd --- data/README.md | 5 +++-- data/README_zh.md | 3 ++- src/llmtuner/data/loader.py | 13 +++++++++++++ src/llmtuner/data/parser.py | 4 +++- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/data/README.md b/data/README.md index fa2c9ee0..c4a1b298 100644 --- a/data/README.md +++ b/data/README.md @@ -27,8 +27,9 @@ If you are using a custom dataset, please provide your dataset definition in the "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)", "observation_tag": "the value of the role_tag represents the tool results. (default: observation)", "function_tag": "the value of the role_tag represents the function call. (default: function_call)", - "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)" - } + "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)", + }, + "sample_num": "the number of samples from this dataset can be greater than the total amount of the dataset. (default: None)" } ``` diff --git a/data/README_zh.md b/data/README_zh.md index e0004f4a..6396688a 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -28,7 +28,8 @@ "observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)", "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)", "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)" - } + }, + "sample_num": "从该数据集采样的数量,可大于该数据集总量(默认:None)" } ``` diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 935695ad..bebe5718 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -1,5 +1,7 @@ import inspect import os +import numpy as np +from numpy.random import RandomState from typing import TYPE_CHECKING, Literal, Union from datasets import load_dataset, load_from_disk @@ -108,6 +110,17 @@ def load_single_dataset( num_samples = min(data_args.max_samples, len(dataset)) dataset = dataset.select(range(num_samples)) + if dataset_attr.sample_num: + dataset_sample_num = dataset_attr.sample_num + logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本") + random_state = RandomState(42) + idx = random_state.permutation(len(dataset))[:dataset_sample_num] + dataset_sample_num -= len(idx) + if dataset_sample_num > 0: + idx2 = random_state.choice(len(dataset), dataset_sample_num) + idx = np.concatenate([idx, idx2], axis=0) + dataset = dataset.select(idx) + return align_dataset(dataset, dataset_attr, data_args) diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 861396a0..9746b5b2 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -44,6 +44,7 @@ class DatasetAttr: observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" + sample_num: Optional[int] = None def __repr__(self) -> str: return self.dataset_name @@ -90,7 +91,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") - + dataset_attr.set_attr("sample_num", dataset_info[name]) + if "columns" in dataset_info[name]: column_names = ["system"] if dataset_attr.formatting == "alpaca": From 890926e60c8585f60b782f77305331f0afe4f54a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 30 May 2024 00:04:26 +0800 Subject: [PATCH 2/7] Update README.md Former-commit-id: 65fb69e388c0a04c15ecd11441e567966f51fae5 --- data/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/data/README.md b/data/README.md index dd7ca201..5ceae666 100644 --- a/data/README.md +++ b/data/README.md @@ -12,6 +12,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format. "ranking": "whether the dataset is a preference dataset or not. (default: False)", "subset": "the name of the subset. (optional, default: None)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)", + "num_samples": "the number of samples in the dataset used for training. (optional, default: None)", "columns (optional)": { "prompt": "the column name in the dataset containing the prompts. (default: instruction)", "query": "the column name in the dataset containing the queries. (default: input)", @@ -32,9 +33,8 @@ Currently we support datasets in **alpaca** and **sharegpt** format. "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)", "observation_tag": "the value of the role_tag represents the tool results. (default: observation)", "function_tag": "the value of the role_tag represents the function call. (default: function_call)", - "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)", - }, - "sample_num": "the number of samples from this dataset can be greater than the total amount of the dataset. (default: None)" + "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)" + } } ``` From 91cc571e6e6cd9fd849d710aac6567d06677fa72 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 30 May 2024 00:04:47 +0800 Subject: [PATCH 3/7] Update README_zh.md Former-commit-id: 3007d260ed45169583a74497a53b661337dd5f71 --- data/README_zh.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/data/README_zh.md b/data/README_zh.md index 1427e48d..1795f352 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -12,6 +12,7 @@ "ranking": "是否为偏好数据集(可选,默认:False)", "subset": "数据集子集的名称(可选,默认:None)", "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)", + "num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)", "columns(可选)": { "prompt": "数据集代表提示词的表头名称(默认:instruction)", "query": "数据集代表请求的表头名称(默认:input)", @@ -32,9 +33,8 @@ "assistant_tag": "消息中代表助手的 role_tag(默认:gpt)", "observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)", "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)", - "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)" - }, - "sample_num": "从该数据集采样的数量,可大于该数据集总量(默认:None)" + "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system column)" + } } ``` From 05e6fe42875afcdb597c7053a2e3bfb2cd32fcc6 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 30 May 2024 00:05:20 +0800 Subject: [PATCH 4/7] Update parser.py Former-commit-id: 310cc11e8c83f16fc5bccc349c38fea347ea9a97 --- src/llamafactory/data/parser.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 99b71cf0..ec97bfc1 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -20,11 +20,12 @@ class DatasetAttr: """ basic configs """ load_from: Literal["hf_hub", "ms_hub", "script", "file"] dataset_name: str + formatting: Literal["alpaca", "sharegpt"] = "alpaca" + ranking: bool = False """ extra configs """ subset: Optional[str] = None folder: Optional[str] = None - ranking: bool = False - formatting: Literal["alpaca", "sharegpt"] = "alpaca" + num_samples: Optional[int] = None """ common columns """ system: Optional[str] = None tools: Optional[str] = None @@ -48,7 +49,6 @@ class DatasetAttr: observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" - sample_num: Optional[int] = None def __repr__(self) -> str: return self.dataset_name @@ -103,12 +103,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: else: dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) + dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") + dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name]) - dataset_attr.set_attr("ranking", dataset_info[name], default=False) - dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") - dataset_attr.set_attr("sample_num", dataset_info[name]) - + dataset_attr.set_attr("num_samples", dataset_info[name]) + if "columns" in dataset_info[name]: column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] if dataset_attr.formatting == "alpaca": From 5f67fdaac99e88d76cf7a9dfe210cd96f2054e1b Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 30 May 2024 00:12:12 +0800 Subject: [PATCH 5/7] Update loader.py Former-commit-id: 19d8fd62c18ee3ba0e431fc241f7d315cb716fef --- src/llamafactory/data/loader.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 5ce4392e..322eefa0 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -1,10 +1,9 @@ import inspect import os -import numpy as np -from numpy.random import RandomState import sys from typing import TYPE_CHECKING, Literal, Optional, Union +import numpy as np from datasets import load_dataset, load_from_disk from ..extras.constants import FILEEXT2TYPE @@ -108,20 +107,14 @@ def load_single_dataset( if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter - if data_args.max_samples is not None: # truncate dataset - num_samples = min(data_args.max_samples, len(dataset)) - dataset = dataset.select(range(num_samples)) + if dataset_attr.num_samples is not None and not data_args.streaming: + indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples] + dataset = dataset.select(indexes) + logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) - if dataset_attr.sample_num: - dataset_sample_num = dataset_attr.sample_num - logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本") - random_state = RandomState(42) - idx = random_state.permutation(len(dataset))[:dataset_sample_num] - dataset_sample_num -= len(idx) - if dataset_sample_num > 0: - idx2 = random_state.choice(len(dataset), dataset_sample_num) - idx = np.concatenate([idx, idx2], axis=0) - dataset = dataset.select(idx) + if data_args.max_samples is not None: # truncate dataset + indexes = np.random.permutation(len(dataset))[: data_args.max_samples] + dataset = dataset.select(indexes) return align_dataset(dataset, dataset_attr, data_args) From a67199246de33f151f056cb5ce78bca2d441fe34 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 30 May 2024 00:17:21 +0800 Subject: [PATCH 6/7] Update loader.py Former-commit-id: aa7f335e3ad5a78e4ed5f99c120be28e9733ea2e --- src/llamafactory/data/loader.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 322eefa0..fa5b12c5 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -108,7 +108,13 @@ def load_single_dataset( dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter if dataset_attr.num_samples is not None and not data_args.streaming: - indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples] + target_num = dataset_attr.num_samples + indexes = np.random.permutation(len(dataset))[:target_num] + target_num -= len(indexes) + if target_num > 0: + expand_indexes = np.random.choice(len(dataset), target_num) + indexes = np.concatenate((indexes, expand_indexes), axis=0) + dataset = dataset.select(indexes) logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) From 391eca66cffc64ec073cfa6f7e79862c687dd73a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 30 May 2024 00:20:20 +0800 Subject: [PATCH 7/7] Update loader.py Former-commit-id: 0aa59322906d91c5e385c9c02ebb5dd64ba060f3 --- src/llamafactory/data/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index fa5b12c5..d4a19e27 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -115,6 +115,7 @@ def load_single_dataset( expand_indexes = np.random.choice(len(dataset), target_num) indexes = np.concatenate((indexes, expand_indexes), axis=0) + assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." dataset = dataset.select(indexes) logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))