From 4b29d9d2b0dce4e6ffd8c0e6053e6f7a6b3036fe Mon Sep 17 00:00:00 2001 From: codemayq Date: Wed, 23 Aug 2023 18:54:23 +0800 Subject: [PATCH 1/4] add dataset stage and filter dataset when stage chosen in webui Former-commit-id: c0e4d1e81b41c9a36291d8bee46d7d807c898c21 --- data/dataset_info.json | 99 +++++++++++++++++--------- src/llmtuner/extras/constants.py | 8 +++ src/llmtuner/webui/common.py | 11 ++- src/llmtuner/webui/components/train.py | 3 +- 4 files changed, 84 insertions(+), 37 deletions(-) diff --git a/data/dataset_info.json b/data/dataset_info.json index 3eaf920e..5fd4fb1f 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -1,23 +1,28 @@ { "alpaca_en": { "file_name": "alpaca_data_en_52k.json", - "file_sha1": "607f94a7f581341e59685aef32f531095232cf23" + "file_sha1": "607f94a7f581341e59685aef32f531095232cf23", + "stage": "sft" }, "alpaca_zh": { "file_name": "alpaca_data_zh_51k.json", - "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311" + "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311", + "stage": "sft" }, "alpaca_gpt4_en": { "file_name": "alpaca_gpt4_data_en.json", - "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a" + "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a", + "stage": "sft" }, "alpaca_gpt4_zh": { "file_name": "alpaca_gpt4_data_zh.json", - "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845" + "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845", + "stage": "sft" }, "self_cognition": { "file_name": "self_cognition.json", - "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67" + "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67", + "stage": "sft" }, "oaast_sft": { "file_name": "oaast_sft.json", @@ -27,7 +32,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "oaast_sft_zh": { "file_name": "oaast_sft_zh.json", @@ -37,7 +43,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "sharegpt_zh": { "file_name": "sharegpt_zh_27k.json", @@ -47,7 +54,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "lima": { "file_name": "lima.json", @@ -57,7 +65,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "example": { "script_url": "example_dataset", @@ -66,25 +75,32 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "guanaco": { - "hf_hub_url": "JosephusCheung/GuanacoDataset" + "hf_hub_url": "JosephusCheung/GuanacoDataset", + "stage": "sft" }, "belle_0.5m": { - "hf_hub_url": "BelleGroup/train_0.5M_CN" + "hf_hub_url": "BelleGroup/train_0.5M_CN", + "stage": "sft" }, "belle_1m": { - "hf_hub_url": "BelleGroup/train_1M_CN" + "hf_hub_url": "BelleGroup/train_1M_CN", + "stage": "sft" }, "belle_2m": { - "hf_hub_url": "BelleGroup/train_2M_CN" + "hf_hub_url": "BelleGroup/train_2M_CN", + "stage": "sft" }, "belle_dialog": { - "hf_hub_url": "BelleGroup/generated_chat_0.4M" + "hf_hub_url": "BelleGroup/generated_chat_0.4M", + "stage": "sft" }, "belle_math": { - "hf_hub_url": "BelleGroup/school_math_0.25M" + "hf_hub_url": "BelleGroup/school_math_0.25M", + "stage": "sft" }, "belle_multiturn": { "script_url": "belle_multiturn", @@ -93,7 +109,8 @@ "query": "", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "firefly": { "hf_hub_url": "YeungNLP/firefly-train-1.1M", @@ -102,13 +119,16 @@ "query": "", "response": "target", "history": "" - } + }, + "stage": "sft" }, "codealpaca": { - "hf_hub_url": "sahil2801/CodeAlpaca-20k" + 
"hf_hub_url": "sahil2801/CodeAlpaca-20k", + "stage": "sft" }, "alpaca_cot": { - "hf_hub_url": "QingyiSi/Alpaca-CoT" + "hf_hub_url": "QingyiSi/Alpaca-CoT", + "stage": "sft" }, "webqa": { "hf_hub_url": "suolyer/webqa", @@ -117,7 +137,8 @@ "query": "", "response": "output", "history": "" - } + }, + "stage": "sft" }, "ultra_chat": { "script_url": "ultra_chat", @@ -126,18 +147,22 @@ "query": "", "response": "output", "history": "history" - } + }, + "stage": "sft" }, "novel_tokens512_50k": { - "hf_hub_url": "zxbsmk/webnovel_cn" + "hf_hub_url": "zxbsmk/webnovel_cn", + "stage": "sft" }, "comparison_gpt4_en": { "file_name": "comparison_gpt4_data_en.json", - "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae" + "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae", + "stage": "rm" }, "comparison_gpt4_zh": { "file_name": "comparison_gpt4_data_zh.json", - "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd" + "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd", + "stage": "rm" }, "hh_rlhf_en": { "script_url": "hh_rlhf_en", @@ -146,7 +171,8 @@ "query": "", "response": "output", "history": "history" - } + }, + "stage": "rm" }, "oaast_rm": { "file_name": "oaast_rm.json", @@ -156,7 +182,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "rm" }, "oaast_rm_zh": { "file_name": "oaast_rm_zh.json", @@ -166,7 +193,8 @@ "query": "input", "response": "output", "history": "history" - } + }, + "stage": "rm" }, "wiki_demo": { "file_name": "wiki_demo.txt", @@ -176,7 +204,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "refinedweb": { "hf_hub_url": "tiiuae/falcon-refinedweb", @@ -185,7 +214,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "starcoder": { "hf_hub_url": "bigcode/starcoderdata", @@ -194,7 +224,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "wikipedia_en": { "hf_hub_url": "olm/olm-wikipedia-20221220", @@ -203,7 +234,8 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" }, "wikipedia_zh": { "hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered", @@ -212,6 +244,7 @@ "query": "", "response": "", "history": "" - } + }, + "stage": "pt" } } diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index cd22943f..f10aaaa3 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -18,6 +18,14 @@ STAGES = [ "Pre-Training" ] +DATASET_STAGE_MAP = { + "SFT": "sft", + "Pre-Training": "pt", + "Reward Modeling": "rm", + "PPO": "sft", + "DPO": "rm" +} + SUPPORTED_MODELS = { "LLaMA-7B": "huggyllama/llama-7b", "LLaMA-13B": "huggyllama/llama-13b", diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 965a690b..98fface4 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -6,7 +6,7 @@ import gradio as gr from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS +from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP DEFAULT_CACHE_DIR = "cache" @@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: return {} -def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: +def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]: dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) - return 
gr.update(value=[], choices=list(dataset_info.keys())) + if stage: + dataset_stage = DATASET_STAGE_MAP[stage] + dataset_info = {key: value for key, value in dataset_info.items() + if ("stage" not in value) or value["stage"] == dataset_stage} + + return gr.update(value=[], choices=list(dataset_info.keys())) \ No newline at end of file diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index aab512ee..7b69944c 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic preview_box, preview_count, preview_samples, close_btn = create_preview_box() - dataset_dir.change(list_dataset, [dataset_dir], [dataset]) + training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset]) + dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) data_preview_btn.click( get_preview, From b032dc4c4eb89afea9dca31a665e8816b7bbf8cf Mon Sep 17 00:00:00 2001 From: codemayq Date: Wed, 23 Aug 2023 19:55:45 +0800 Subject: [PATCH 2/4] add readme for dataset Former-commit-id: cece66d48a770e3e418496445d4040e3cafa9411 --- data/README.md | 6 ++++-- data/README_zh.md | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/data/README.md b/data/README.md index dc1c8bce..a7375b5d 100644 --- a/data/README.md +++ b/data/README.md @@ -11,7 +11,8 @@ If you are using a custom dataset, please provide your dataset definition in the "query": "the name of the column in the datasets containing the queries. (default: input)", "response": "the name of the column in the datasets containing the responses. (default: output)", "history": "the name of the column in the datasets containing the history of chat. 
(default: None)" - } + }, + "stage": "The stage at which the data is being used: pt, sft, and rm, which correspond to pre-training, supervised fine-tuning(PPO), and reward model (DPO) training, respectively.(default: None)" } ``` @@ -26,6 +27,7 @@ For datasets used in reward modeling or DPO training, the `response` column shou "output": [ "Chosen answer", "Rejected answer" - ] + ], + "stage": "rm" } ``` diff --git a/data/README_zh.md b/data/README_zh.md index 054ee8ea..e23a3e70 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -11,7 +11,8 @@ "query": "数据集代表请求的表头名称(默认:input)", "response": "数据集代表回答的表头名称(默认:output)", "history": "数据集代表历史对话的表头名称(默认:None)" - } + }, + "stage": "数据所应用的训练阶段,可选值有 pt, sft, rm 三个,对应预训练,指令监督微调(PPO),奖励模型(DPO)训练, 默认为None,表示不限制" } ``` @@ -26,6 +27,7 @@ "output": [ "Chosen answer", "Rejected answer" - ] + ], + "stage": "rm" } ``` From 2b979d39f271b4e3596c06fa3f30e6776494f70b Mon Sep 17 00:00:00 2001 From: codemayq Date: Wed, 23 Aug 2023 20:54:53 +0800 Subject: [PATCH 3/4] add stage in DatasetAttr Former-commit-id: ba94c8729dd8c90bedf7079a6978e150fc92b737 --- src/llmtuner/hparams/data_args.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 374d03c6..db7702cd 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -11,6 +11,7 @@ class DatasetAttr: dataset_name: Optional[str] = None dataset_sha1: Optional[str] = None system_prompt: Optional[str] = None + stage: Optional[str] = None def __repr__(self) -> str: return self.dataset_name @@ -113,14 +114,21 @@ class DataArguments: raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) if "hf_hub_url" in dataset_info[name]: - dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + dataset_attr = DatasetAttr( + "hf_hub", + dataset_name=dataset_info[name]["hf_hub_url"], + stage=dataset_info[name].get("stage", None)) elif "script_url" in dataset_info[name]: - dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + dataset_attr = DatasetAttr( + "script", + dataset_name=dataset_info[name]["script_url"], + stage=dataset_info[name].get("stage", None)) else: dataset_attr = DatasetAttr( "file", dataset_name=dataset_info[name]["file_name"], - dataset_sha1=dataset_info[name].get("file_sha1", None) + dataset_sha1=dataset_info[name].get("file_sha1", None), + stage=dataset_info[name].get("stage", None) ) if "columns" in dataset_info[name]: From d9b9d9d1fe1a075af350b18f8773bc9c46ecc95c Mon Sep 17 00:00:00 2001 From: codemayq Date: Sun, 27 Aug 2023 20:35:32 +0800 Subject: [PATCH 4/4] add ad gen dataset Former-commit-id: 604f85487b46b3eb01b68cb2cc6535b7cb5527a7 --- README.md | 1 + README_zh.md | 1 + data/dataset_info.json | 10 ++++++++++ 3 files changed, 12 insertions(+) diff --git a/README.md b/README.md index 0db73b16..69dfe649 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) + - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) - For reward modeling or DPO training: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) diff --git a/README_zh.md b/README_zh.md index ec4a524c..628a2b10 100644 --- a/README_zh.md +++ 
b/README_zh.md @@ -105,6 +105,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) + - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) - 用于奖励模型或 DPO 训练: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) diff --git a/data/dataset_info.json b/data/dataset_info.json index 5fd4fb1f..4ae75089 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -154,6 +154,16 @@ "hf_hub_url": "zxbsmk/webnovel_cn", "stage": "sft" }, + "ad_gen": { + "hf_hub_url": "HasturOfficial/adgen", + "columns": { + "prompt": "content", + "query": "", + "response": "summary", + "history": "" + }, + "stage": "sft" + }, "comparison_gpt4_en": { "file_name": "comparison_gpt4_data_en.json", "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae",
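
Taken together, these patches thread the new "stage" field from data/dataset_info.json through DATASET_STAGE_MAP and the web UI's list_dataset helper, so that choosing a training stage filters the dataset dropdown. The snippet below is a minimal, standalone sketch of that filtering behaviour outside Gradio: DATASET_STAGE_MAP is copied from src/llmtuner/extras/constants.py as added in PATCH 1/4, while the filter_by_stage helper and the sample dictionary are illustrative names, not part of the patches.

from typing import Any, Dict, Optional

# Web UI training stage -> dataset "stage" value (from src/llmtuner/extras/constants.py).
DATASET_STAGE_MAP = {
    "SFT": "sft",
    "Pre-Training": "pt",
    "Reward Modeling": "rm",
    "PPO": "sft",
    "DPO": "rm",
}

def filter_by_stage(dataset_info: Dict[str, Any], stage: Optional[str] = None) -> Dict[str, Any]:
    # Mirror the filter added to list_dataset(): keep datasets whose "stage" matches
    # the mapped value, and always keep entries that declare no "stage" at all.
    if not stage:
        return dataset_info
    dataset_stage = DATASET_STAGE_MAP[stage]
    return {
        name: attr for name, attr in dataset_info.items()
        if "stage" not in attr or attr["stage"] == dataset_stage
    }

if __name__ == "__main__":
    sample = {  # illustrative subset of data/dataset_info.json
        "alpaca_en": {"file_name": "alpaca_data_en_52k.json", "stage": "sft"},
        "comparison_gpt4_en": {"file_name": "comparison_gpt4_data_en.json", "stage": "rm"},
        "wiki_demo": {"file_name": "wiki_demo.txt", "stage": "pt"},
        "legacy_dataset": {"file_name": "legacy.json"},  # no "stage": always listed
    }
    print(sorted(filter_by_stage(sample, "PPO")))  # ['alpaca_en', 'legacy_dataset']
    print(sorted(filter_by_stage(sample, "DPO")))  # ['comparison_gpt4_en', 'legacy_dataset']

Note that PPO reuses sft-stage datasets and DPO reuses rm-stage datasets, which is why DATASET_STAGE_MAP maps both pairs to the same values, and that entries without a "stage" key remain visible for every stage, so existing dataset_info.json files keep working unchanged.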