From 4721d0b8ff7a571664af471914681d30c8254bcc Mon Sep 17 00:00:00 2001 From: hzhaoy Date: Tue, 4 Jun 2024 10:33:43 +0800 Subject: [PATCH] add: support selecting saved configuration files and loading training parameters Former-commit-id: b27c4cfcb367f7ab0b56da3ba238c4d9c29ff4e7 --- src/llamafactory/webui/components/train.py | 5 +++-- src/llamafactory/webui/utils.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 6f742bb1..fabb91ea 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -6,7 +6,7 @@ from ...extras.constants import TRAINING_STAGES from ...extras.misc import get_device_count from ...extras.packages import is_gradio_available from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets -from ..utils import change_stage, check_output_dir, list_output_dirs +from ..utils import change_stage, check_output_dir, list_output_dirs, list_config_paths from .data import create_preview_box @@ -259,7 +259,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): initial_dir = gr.Textbox(visible=False, interactive=False) output_dir = gr.Dropdown(allow_custom_value=True) - config_path = gr.Textbox() + config_path = gr.Dropdown(allow_custom_value=True) with gr.Row(): device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False) @@ -317,5 +317,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: output_dir.change( list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None) + config_path.change(list_config_paths, outputs=[config_path], concurrency_limit=None) return elem_dict diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index 09cefa0e..37df1b52 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -176,6 +176,18 @@ def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> return gr.Dropdown(choices=output_dirs) +def list_config_paths() -> "gr.Dropdown": + """ + Lists all the saved configuration files that can be loaded. + """ + if os.path.exists(DEFAULT_CONFIG_DIR) and os.path.isdir(DEFAULT_CONFIG_DIR): + config_files = [file_name for file_name in os.listdir(DEFAULT_CONFIG_DIR) if file_name.endswith(".yaml")] + else: + config_files = [] + + return gr.Dropdown(choices=config_files) + + def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None: r""" Check if output dir exists.