diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index cd22943f..f10aaaa3 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -18,6 +18,14 @@ STAGES = [ "Pre-Training" ] +DATASET_STAGE_MAP = { + "SFT": "sft", + "Pre-Training": "pt", + "Reward Modeling": "rm", + "PPO": "sft", + "DPO": "rm" +} + SUPPORTED_MODELS = { "LLaMA-7B": "huggyllama/llama-7b", "LLaMA-13B": "huggyllama/llama-13b", diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 965a690b..98fface4 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -6,7 +6,7 @@ import gradio as gr from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS +from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP DEFAULT_CACHE_DIR = "cache" @@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: return {} -def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: +def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]: dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) - return gr.update(value=[], choices=list(dataset_info.keys())) + if stage: + dataset_stage = DATASET_STAGE_MAP[stage] + dataset_info = {key: value for key, value in dataset_info.items() + if ("stage" not in value) or value["stage"] == dataset_stage} + + return gr.update(value=[], choices=list(dataset_info.keys())) \ No newline at end of file diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index aab512ee..7b69944c 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic preview_box, preview_count, preview_samples, close_btn = create_preview_box() - dataset_dir.change(list_dataset, [dataset_dir], [dataset]) + training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset]) + dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) data_preview_btn.click( get_preview,