From 53103f55b695bd40b78e4329f6f304d3b60b0743 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Thu, 19 Dec 2024 19:04:19 +0800 Subject: [PATCH] feat: optimize frontend Former-commit-id: 8c2df41b937f491f7ebf593b20c65a19738c7642 --- src/llamafactory/train/trainer_utils.py | 2 +- src/llamafactory/webui/components/train.py | 12 +++ src/llamafactory/webui/locales.py | 106 +++++++++++++++++++++ src/llamafactory/webui/runner.py | 8 +- 4 files changed, 126 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index e13fe552..583ca3ee 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -465,7 +465,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall """ import swanlab from swanlab.integration.transformers import SwanLabCallback - + if finetuning_args.swanlab_api_key is not None: swanlab.login(api_key=finetuning_args.swanlab_api_key) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index bc776acd..f905e43e 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -273,12 +273,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Accordion(open=False) as swanlab_tab: with gr.Row(): use_swanlab = gr.Checkbox() + swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True) + swanlab_project = gr.Textbox(value="", placeholder="Project name", interactive=True) + swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) + swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True) + swanlab_description = gr.Textbox(value="", placeholder="Experiment description", interactive=True) + swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) input_elems.update({use_swanlab}) elem_dict.update( dict( swanlab_tab=swanlab_tab, use_swanlab=use_swanlab, + swanlab_api_key=swanlab_api_key, + swanlab_project=swanlab_project, + swanlab_workspace=swanlab_workspace, + swanlab_experiment_name=swanlab_experiment_name, + swanlab_description=swanlab_description, + swanlab_mode=swanlab_mode, ) ) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index c88800d2..2167bfd1 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1367,6 +1367,112 @@ LOCALES = { "label": "SwanLab 설정", }, }, + "use_swanlab": { + "en": { + "label": "Use SwanLab", + "info": "Enable SwanLab for experiment tracking and visualization.", + }, + "ru": { + "label": "Использовать SwanLab", + "info": "Включить SwanLab для отслеживания и визуализации экспериментов.", + }, + "zh": { + "label": "使用 SwanLab", + "info": "启用 SwanLab 进行实验跟踪和可视化。", + }, + "ko": { + "label": "SwanLab 사용", + "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.", + }, + }, + "swanlab_api_key": { + "en": { + "label": "API key", + "info": "The API key for SwanLab.", + }, + "ru": { + "label": "API ключ", + "info": "API ключ для SwanLab.", + }, + "zh": { + "label": "API 密钥", + "info": "SwanLab 的 API 密钥。", + }, + "ko": { + "label": "API 키", + "info": "SwanLab의 API 키.", + }, + }, + "swanlab_project": { + "en": { + "label": "SwanLab project", + }, + "ru": { + "label": "Проект SwanLab", + }, + "zh": { + "label": "SwanLab 项目", + }, + "ko": { + "label": "SwanLab 프로젝트", + }, + }, + "swanlab_workspace": { + "en": { + "label": "SwanLab workspace", + }, + "ru": { + "label": "Рабочая область SwanLab", + }, + "zh": { + "label": "SwanLab 工作区", + }, + "ko": { + "label": "SwanLab 작업 영역", + }, + }, + "swanlab_experiment_name": { + "en": { + "label": "SwanLab experiment name", + }, + "ru": { + "label": "Имя эксперимента SwanLab", + }, + "zh": { + "label": "SwanLab 实验名称", + }, + "ko": { + "label": "SwanLab 실험 이름", + }, + }, + "swanlab_description": { + "en": { + "label": "SwanLab experiment description", + }, + "ru": { + "label": "Описание эксперимента SwanLab", + }, + "zh": { + "label": "SwanLab 实验描述", + }, + "ko": { + "label": "SwanLab 실험 설명", + }, + }, + "swanlab_mode": { + "en": { + "label": "SwanLab mode", + }, + "ru": { + "label": "Режим SwanLab", + }, + "zh": { + "label": "SwanLab 模式", + }, + "ko": { + "label": "SwanLab 모드", + }, + }, "cmd_preview_btn": { "en": { "value": "Preview command", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 7bd61390..6c2b4feb 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -231,7 +231,13 @@ class Runner: # swanlab config if get("train.use_swanlab"): - args["swanlab_name"] = get("train.swanlab_name") + args["swanlab_api_key"] = get("train.swanlab_api_key") + args["swanlab_project"] = get("train.swanlab_project") + args["swanlab_workspace"] = get("train.swanlab_workspace") + args["swanlab_experiment_name"] = get("train.swanlab_experiment_name") + args["swanlab_description"] = get("train.swanlab_description") + args["swanlab_mode"] = get("train.swanlab_mode") + # eval config if get("train.val_size") > 1e-6 and args["stage"] != "ppo":