[webui] display swanlab exp link (#7089)

* webui add swanlab link

* change callback name

* update

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 891c4875039e8e3b7d0de025ee61c4ff003ff0c4
This commit is contained in:
Ze-Yi LIN 2025-02-27 19:40:54 +08:00 committed by GitHub
parent e86cb8a4fa
commit 210cdb9557
6 changed files with 63 additions and 17 deletions

View File

@ -87,6 +87,8 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
SWANLAB_CONFIG = "swanlab_public_config.json"
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>") VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"

View File

@ -17,6 +17,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import os
from collections.abc import Mapping from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@ -31,7 +33,7 @@ from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override from typing_extensions import override
from ..extras import logging from ..extras import logging
from ..extras.constants import IGNORE_INDEX from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@ -51,7 +53,7 @@ if is_ray_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, TrainerCallback from transformers import PreTrainedModel, TrainerCallback, TrainerState
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, RayArguments, TrainingArguments from ..hparams import DataArguments, RayArguments, TrainingArguments
@ -592,7 +594,17 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
if finetuning_args.swanlab_api_key is not None: if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key) swanlab.login(api_key=finetuning_args.swanlab_api_key)
swanlab_callback = SwanLabCallback( class SwanLabCallbackExtension(SwanLabCallback):
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs) -> None:
    r"""
    Sets up the SwanLab run and saves its public config into the output directory.

    Only the world-zero process does any work; every other process returns immediately.
    After the parent class finishes its own setup, the run's public info is serialized
    as JSON to ``SWANLAB_CONFIG`` inside ``args.output_dir``.
    """
    if not state.is_world_process_zero:
        return

    super().setup(args, state, model, **kwargs)
    swanlab_public_config = self._experiment.get_run().public.json()

    # Creating the directory is a no-op when it already exists, and avoids a
    # FileNotFoundError from open() when it does not.
    os.makedirs(args.output_dir, exist_ok=True)

    # Write with an explicit UTF-8 encoding instead of the platform default, so the
    # file is decoded the same way by a reader that opens it with encoding="utf-8".
    with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w", encoding="utf-8") as f:
        json.dump(swanlab_public_config, f, indent=2)
swanlab_callback = SwanLabCallbackExtension(
project=finetuning_args.swanlab_project, project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace, workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_run_name, experiment_name=finetuning_args.swanlab_run_name,

View File

@ -299,9 +299,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
swanlab_workspace = gr.Textbox() swanlab_workspace = gr.Textbox()
swanlab_api_key = gr.Textbox() swanlab_api_key = gr.Textbox()
swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud") swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
swanlab_link = gr.Markdown(visible=False, container=True)
input_elems.update( input_elems.update(
{use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode} {
use_swanlab,
swanlab_project,
swanlab_run_name,
swanlab_workspace,
swanlab_api_key,
swanlab_mode,
swanlab_link,
}
) )
elem_dict.update( elem_dict.update(
dict( dict(
@ -312,6 +321,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
swanlab_workspace=swanlab_workspace, swanlab_workspace=swanlab_workspace,
swanlab_api_key=swanlab_api_key, swanlab_api_key=swanlab_api_key,
swanlab_mode=swanlab_mode, swanlab_mode=swanlab_mode,
swanlab_link=swanlab_link,
) )
) )
@ -364,7 +374,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
loss_viewer=loss_viewer, loss_viewer=loss_viewer,
) )
) )
output_elems = [output_box, progress_bar, loss_viewer] output_elems = [output_box, progress_bar, loss_viewer, swanlab_link]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None) cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
start_btn.click(engine.runner.run_train, input_elems, output_elems) start_btn.click(engine.runner.run_train, input_elems, output_elems)

View File

@ -23,6 +23,7 @@ from ..extras.constants import (
PEFT_METHODS, PEFT_METHODS,
RUNNING_LOG, RUNNING_LOG,
STAGES_USE_PAIR_DATA, STAGES_USE_PAIR_DATA,
SWANLAB_CONFIG,
TRAINER_LOG, TRAINER_LOG,
TRAINING_STAGES, TRAINING_STAGES,
) )
@ -30,6 +31,7 @@ from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod from ..model import QuantizationMethod
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
from .locales import ALERTS
if is_gradio_available(): if is_gradio_available():
@ -86,20 +88,20 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
return get_model_path(model_name), get_template(model_name) return get_model_path(model_name), get_template(model_name)
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
r""" r"""
Gets training information for monitor. Gets training information for monitor.
If do_train is True: If do_train is True:
Inputs: train.output_path Inputs: top.lang, train.output_path
Outputs: train.output_box, train.progress_bar, train.loss_viewer Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
If do_train is False: If do_train is False:
Inputs: eval.output_path Inputs: top.lang, eval.output_path
Outputs: eval.output_box, eval.progress_bar, None Outputs: eval.output_box, eval.progress_bar, None, None
""" """
running_log = "" running_log = ""
running_progress = gr.Slider(visible=False) running_progress = gr.Slider(visible=False)
running_loss = None running_info = {}
running_log_path = os.path.join(output_path, RUNNING_LOG) running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path): if os.path.isfile(running_log_path):
@ -125,9 +127,19 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
running_progress = gr.Slider(label=label, value=percentage, visible=True) running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available(): if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log)) running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
if os.path.isfile(swanlab_config_path):
with open(swanlab_config_path, encoding="utf-8") as f:
swanlab_public_config = json.load(f)
swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
if swanlab_link is not None:
running_info["swanlab_link"] = gr.Markdown(
ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True
)
return running_log, running_progress, running_info
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":

View File

@ -2814,4 +2814,11 @@ ALERTS = {
"ko": "모델이 내보내졌습니다.", "ko": "모델이 내보내졌습니다.",
"ja": "モデルのエクスポートが完了しました。", "ja": "モデルのエクスポートが完了しました。",
}, },
"info_swanlab_link": {
"en": "### SwanLab Link\n",
"ru": "### SwanLab ссылка\n",
"zh": "### SwanLab 链接\n",
"ko": "### SwanLab 링크\n",
"ja": "### SwanLab リンク\n",
},
} }

View File

@ -423,6 +423,7 @@ class Runner:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval")) output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval")) progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None
running_log = "" running_log = ""
while self.trainer is not None: while self.trainer is not None:
@ -432,16 +433,18 @@ class Runner:
progress_bar: gr.Slider(visible=False), progress_bar: gr.Slider(visible=False),
} }
else: else:
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train) running_log, running_progress, running_info = get_trainer_info(lang, output_path, self.do_train)
return_dict = { return_dict = {
output_box: running_log, output_box: running_log,
progress_bar: running_progress, progress_bar: running_progress,
} }
if running_loss is not None: if "loss_viewer" in running_info:
return_dict[loss_viewer] = running_loss return_dict[loss_viewer] = running_info["loss_viewer"]
if "swanlab_link" in running_info:
return_dict[swanlab_link] = running_info["swanlab_link"]
yield return_dict yield return_dict
try: try:
self.trainer.wait(2) self.trainer.wait(2)
self.trainer = None self.trainer = None