Merge pull request #651 from hiyouga/feature-dataset_stage

add dataset stage

Former-commit-id: 3b0ef57405cbc22ff8ce4eef2cfcb73872519db5
This commit is contained in:
codingma 2023-08-28 16:03:45 +08:00 committed by GitHub
commit f7658db1b6
8 changed files with 39 additions and 11 deletions

View File

@ -105,6 +105,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://arxiv.org/abs/1908.06605)
- For reward modeling or DPO training: - For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)

View File

@ -105,6 +105,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://arxiv.org/abs/1908.06605)
- 用于奖励模型或 DPO 训练: - 用于奖励模型或 DPO 训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)

View File

@ -11,7 +11,8 @@ If you are using a custom dataset, please provide your dataset definition in the
"query": "the name of the column in the datasets containing the queries. (default: input)", "query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)", "response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)" "history": "the name of the column in the datasets containing the history of chat. (default: None)"
} },
"stage": "The stage at which the data is being used: pt, sft, and rm, which correspond to pre-training, supervised fine-tuning(PPO), and reward model (DPO) training, respectively.(default: None)"
} }
``` ```
@ -26,6 +27,7 @@ For datasets used in reward modeling or DPO training, the `response` column shou
"output": [ "output": [
"Chosen answer", "Chosen answer",
"Rejected answer" "Rejected answer"
] ],
"stage": "rm"
} }
``` ```

View File

@ -11,7 +11,8 @@
"query": "数据集代表请求的表头名称默认input", "query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output", "response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None" "history": "数据集代表历史对话的表头名称默认None"
} },
"stage": "数据所应用的训练阶段,可选值有 pt, sft, rm 三个,对应预训练,指令监督微调(PPO),奖励模型(DPO)训练, 默认为None表示不限制"
} }
``` ```
@ -26,6 +27,7 @@
"output": [ "output": [
"Chosen answer", "Chosen answer",
"Rejected answer" "Rejected answer"
] ],
"stage": "rm"
} }
``` ```

View File

@ -18,6 +18,14 @@ STAGES = [
"Pre-Training" "Pre-Training"
] ]
DATASET_STAGE_MAP = {
"SFT": "sft",
"Pre-Training": "pt",
"Reward Modeling": "rm",
"PPO": "sft",
"DPO": "rm"
}
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b", "LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b", "LLaMA-13B": "huggyllama/llama-13b",

View File

@ -11,6 +11,7 @@ class DatasetAttr:
dataset_name: Optional[str] = None dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
stage: Optional[str] = None
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
@ -113,14 +114,21 @@ class DataArguments:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
if "hf_hub_url" in dataset_info[name]: if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) dataset_attr = DatasetAttr(
"hf_hub",
dataset_name=dataset_info[name]["hf_hub_url"],
stage=dataset_info[name].get("stage", None))
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) dataset_attr = DatasetAttr(
"script",
dataset_name=dataset_info[name]["script_url"],
stage=dataset_info[name].get("stage", None))
else: else:
dataset_attr = DatasetAttr( dataset_attr = DatasetAttr(
"file", "file",
dataset_name=dataset_info[name]["file_name"], dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None) dataset_sha1=dataset_info[name].get("file_sha1", None),
stage=dataset_info[name].get("stage", None)
) )
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:

View File

@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
DEFAULT_CACHE_DIR = "cache" DEFAULT_CACHE_DIR = "cache"
@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {} return {}
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
if stage:
dataset_stage = DATASET_STAGE_MAP[stage]
dataset_info = {key: value for key, value in dataset_info.items()
if ("stage" not in value) or value["stage"] == dataset_stage}
return gr.update(value=[], choices=list(dataset_info.keys())) return gr.update(value=[], choices=list(dataset_info.keys()))

View File

@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
preview_box, preview_count, preview_samples, close_btn = create_preview_box() preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset]) training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
data_preview_btn.click( data_preview_btn.click(
get_preview, get_preview,