mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-15 16:18:10 +08:00
add dataset stage and filter dataset when stage chosen in webui
Former-commit-id: 26e4136449a4df6028d834fd16a0f4a7c532759d
This commit is contained in:
parent
4606340f0f
commit
cbc7db3478
@ -18,6 +18,14 @@ STAGES = [
|
|||||||
"Pre-Training"
|
"Pre-Training"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
DATASET_STAGE_MAP = {
|
||||||
|
"SFT": "sft",
|
||||||
|
"Pre-Training": "pt",
|
||||||
|
"Reward Modeling": "rm",
|
||||||
|
"PPO": "sft",
|
||||||
|
"DPO": "rm"
|
||||||
|
}
|
||||||
|
|
||||||
SUPPORTED_MODELS = {
|
SUPPORTED_MODELS = {
|
||||||
"LLaMA-7B": "huggyllama/llama-7b",
|
"LLaMA-7B": "huggyllama/llama-7b",
|
||||||
"LLaMA-13B": "huggyllama/llama-13b",
|
"LLaMA-13B": "huggyllama/llama-13b",
|
||||||
|
@ -6,7 +6,7 @@ import gradio as gr
|
|||||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||||
|
|
||||||
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
|
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CACHE_DIR = "cache"
|
DEFAULT_CACHE_DIR = "cache"
|
||||||
@ -78,6 +78,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
|
def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
|
||||||
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
||||||
return gr.update(value=[], choices=list(dataset_info.keys()))
|
if stage:
|
||||||
|
dataset_stage = DATASET_STAGE_MAP[stage]
|
||||||
|
dataset_info = {key: value for key, value in dataset_info.items()
|
||||||
|
if ("stage" not in value) or value["stage"] == dataset_stage}
|
||||||
|
|
||||||
|
return gr.update(value=[], choices=list(dataset_info.keys()))
|
@ -22,7 +22,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
|||||||
|
|
||||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||||
|
|
||||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
|
||||||
|
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
|
||||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
|
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
|
||||||
data_preview_btn.click(
|
data_preview_btn.click(
|
||||||
get_preview,
|
get_preview,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user