add dataset stage and filter dataset when stage chosen in webui

This commit is contained in:
codemayq
2023-08-23 18:54:23 +08:00
parent 1c702ad538
commit c0e4d1e81b
4 changed files with 84 additions and 37 deletions

View File

@@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
DEFAULT_CACHE_DIR = "cache"
@@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {}
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
return gr.update(value=[], choices=list(dataset_info.keys()))
if stage:
dataset_stage = DATASET_STAGE_MAP[stage]
dataset_info = {key: value for key, value in dataset_info.items()
if ("stage" not in value) or value["stage"] == dataset_stage}
return gr.update(value=[], choices=list(dataset_info.keys()))