From cbc7db34784f336db087664063a15601cd410f9d Mon Sep 17 00:00:00 2001 From: codemayq Date: Wed, 23 Aug 2023 18:54:23 +0800 Subject: [PATCH 1/4] add dataset stage and filter dataset when stage chosen in webui Former-commit-id: 26e4136449a4df6028d834fd16a0f4a7c532759d --- src/llmtuner/extras/constants.py | 8 ++++++++ src/llmtuner/webui/common.py | 11 ++++++++--- src/llmtuner/webui/components/train.py | 3 ++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index cd22943f..f10aaaa3 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -18,6 +18,14 @@ STAGES = [ "Pre-Training" ] +DATASET_STAGE_MAP = { + "SFT": "sft", + "Pre-Training": "pt", + "Reward Modeling": "rm", + "PPO": "sft", + "DPO": "rm" +} + SUPPORTED_MODELS = { "LLaMA-7B": "huggyllama/llama-7b", "LLaMA-13B": "huggyllama/llama-13b", diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 965a690b..98fface4 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -6,7 +6,7 @@ import gradio as gr from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS +from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP DEFAULT_CACHE_DIR = "cache" @@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: return {} -def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: +def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]: dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) - return gr.update(value=[], choices=list(dataset_info.keys())) + if stage: + dataset_stage = DATASET_STAGE_MAP[stage] + dataset_info = {key: value for key, value in dataset_info.items() + if ("stage" not in value) or value["stage"] == dataset_stage} + + return gr.update(value=[], choices=list(dataset_info.keys())) \ No newline at end of file diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index aab512ee..7b69944c 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic preview_box, preview_count, preview_samples, close_btn = create_preview_box() - dataset_dir.change(list_dataset, [dataset_dir], [dataset]) + training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset]) + dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) data_preview_btn.click( get_preview, From a6662b73f5b02226034c833f689556c575f70326 Mon Sep 17 00:00:00 2001 From: codemayq Date: Wed, 23 Aug 2023 19:55:45 +0800 Subject: [PATCH 2/4] add readme for dataset Former-commit-id: bdcb0ea40e726e4c5752f938b379ed9a18e7e1d0 --- data/README.md | 6 ++++-- data/README_zh.md | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/data/README.md b/data/README.md index dc1c8bce..a7375b5d 100644 --- a/data/README.md +++ b/data/README.md @@ -11,7 +11,8 @@ If you are using a custom dataset, please provide your dataset definition in the "query": "the name of the column in the datasets containing the queries. (default: input)", "response": "the name of the column in the datasets containing the responses. (default: output)", "history": "the name of the column in the datasets containing the history of chat. (default: None)" - } + }, + "stage": "The stage at which the data is being used: pt, sft, and rm, which correspond to pre-training, supervised fine-tuning(PPO), and reward model (DPO) training, respectively.(default: None)" } ``` @@ -26,6 +27,7 @@ For datasets used in reward modeling or DPO training, the `response` column shou "output": [ "Chosen answer", "Rejected answer" - ] + ], + "stage": "rm" } ``` diff --git a/data/README_zh.md b/data/README_zh.md index 054ee8ea..e23a3e70 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -11,7 +11,8 @@ "query": "数据集代表请求的表头名称(默认:input)", "response": "数据集代表回答的表头名称(默认:output)", "history": "数据集代表历史对话的表头名称(默认:None)" - } + }, + "stage": "数据所应用的训练阶段,可选值有 pt, sft, rm 三个,对应预训练,指令监督微调(PPO),奖励模型(DPO)训练, 默认为None,表示不限制" } ``` @@ -26,6 +27,7 @@ "output": [ "Chosen answer", "Rejected answer" - ] + ], + "stage": "rm" } ``` From d3fd8f89b8a98ae2f42eeb4f3316279ac377077c Mon Sep 17 00:00:00 2001 From: codemayq Date: Wed, 23 Aug 2023 20:54:53 +0800 Subject: [PATCH 3/4] add stage in DatasetAttr Former-commit-id: 9c55200d8de0623640f529dbf39b8b0f169636d3 --- src/llmtuner/hparams/data_args.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 374d03c6..db7702cd 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -11,6 +11,7 @@ class DatasetAttr: dataset_name: Optional[str] = None dataset_sha1: Optional[str] = None system_prompt: Optional[str] = None + stage: Optional[str] = None def __repr__(self) -> str: return self.dataset_name @@ -113,14 +114,21 @@ class DataArguments: raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) if "hf_hub_url" in dataset_info[name]: - dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + dataset_attr = DatasetAttr( + "hf_hub", + dataset_name=dataset_info[name]["hf_hub_url"], + stage=dataset_info[name].get("stage", None)) elif "script_url" in dataset_info[name]: - dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + dataset_attr = DatasetAttr( + "script", + dataset_name=dataset_info[name]["script_url"], + stage=dataset_info[name].get("stage", None)) else: dataset_attr = DatasetAttr( "file", dataset_name=dataset_info[name]["file_name"], - dataset_sha1=dataset_info[name].get("file_sha1", None) + dataset_sha1=dataset_info[name].get("file_sha1", None), + stage=dataset_info[name].get("stage", None) ) if "columns" in dataset_info[name]: From b869bc1a20c1c8eb73d50ec4730ebc6941e0dca8 Mon Sep 17 00:00:00 2001 From: codemayq Date: Sun, 27 Aug 2023 20:35:32 +0800 Subject: [PATCH 4/4] add ad gen dataset Former-commit-id: fcd0788aa4dda0cecc1420d369d371032a207810 --- README.md | 1 + README_zh.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 0db73b16..69dfe649 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) + - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) - For reward modeling or DPO training: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) diff --git a/README_zh.md b/README_zh.md index ec4a524c..628a2b10 100644 --- a/README_zh.md +++ b/README_zh.md @@ -105,6 +105,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) + - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) - 用于奖励模型或 DPO 训练: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)