add dataset stage and filter dataset when stage chosen in webui

Former-commit-id: 26e4136449a4df6028d834fd16a0f4a7c532759d
This commit is contained in:
codemayq 2023-08-23 18:54:23 +08:00
parent 4606340f0f
commit cbc7db3478
3 changed files with 18 additions and 4 deletions

View File

@ -18,6 +18,14 @@ STAGES = [
"Pre-Training"
]
DATASET_STAGE_MAP = {
"SFT": "sft",
"Pre-Training": "pt",
"Reward Modeling": "rm",
"PPO": "sft",
"DPO": "rm"
}
SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",

View File

@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
DEFAULT_CACHE_DIR = "cache"
@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {}
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
if stage:
dataset_stage = DATASET_STAGE_MAP[stage]
dataset_info = {key: value for key, value in dataset_info.items()
if ("stage" not in value) or value["stage"] == dataset_stage}
return gr.update(value=[], choices=list(dataset_info.keys()))

View File

@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
data_preview_btn.click(
get_preview,