[webui] display swanlab exp link (#7089)

* webui add swanlab link

* change callback name

* update

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 891c4875039e8e3b7d0de025ee61c4ff003ff0c4
This commit is contained in:
Ze-Yi LIN 2025-02-27 19:40:54 +08:00 committed by GitHub
parent e86cb8a4fa
commit 210cdb9557
6 changed files with 63 additions and 17 deletions

View File

@ -87,6 +87,8 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
SWANLAB_CONFIG = "swanlab_public_config.json"
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin"

View File

@ -17,6 +17,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@ -31,7 +33,7 @@ from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@ -51,7 +53,7 @@ if is_ray_available():
if TYPE_CHECKING:
from transformers import PreTrainedModel, TrainerCallback
from transformers import PreTrainedModel, TrainerCallback, TrainerState
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, RayArguments, TrainingArguments
@ -592,7 +594,17 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
swanlab_callback = SwanLabCallback(
class SwanLabCallbackExtension(SwanLabCallback):
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
if not state.is_world_process_zero:
return
super().setup(args, state, model, **kwargs)
swanlab_public_config = self._experiment.get_run().public.json()
with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f:
f.write(json.dumps(swanlab_public_config, indent=2))
swanlab_callback = SwanLabCallbackExtension(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_run_name,

View File

@ -299,9 +299,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
swanlab_workspace = gr.Textbox()
swanlab_api_key = gr.Textbox()
swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
swanlab_link = gr.Markdown(visible=False, container=True)
input_elems.update(
{use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
{
use_swanlab,
swanlab_project,
swanlab_run_name,
swanlab_workspace,
swanlab_api_key,
swanlab_mode,
swanlab_link,
}
)
elem_dict.update(
dict(
@ -312,6 +321,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
swanlab_workspace=swanlab_workspace,
swanlab_api_key=swanlab_api_key,
swanlab_mode=swanlab_mode,
swanlab_link=swanlab_link,
)
)
@ -364,7 +374,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
loss_viewer=loss_viewer,
)
)
output_elems = [output_box, progress_bar, loss_viewer]
output_elems = [output_box, progress_bar, loss_viewer, swanlab_link]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
start_btn.click(engine.runner.run_train, input_elems, output_elems)

View File

@ -23,6 +23,7 @@ from ..extras.constants import (
PEFT_METHODS,
RUNNING_LOG,
STAGES_USE_PAIR_DATA,
SWANLAB_CONFIG,
TRAINER_LOG,
TRAINING_STAGES,
)
@ -30,6 +31,7 @@ from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
from .locales import ALERTS
if is_gradio_available():
@ -86,20 +88,20 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
return get_model_path(model_name), get_template(model_name)
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
r"""
Gets training information for monitor.
If do_train is True:
Inputs: train.output_path
Outputs: train.output_box, train.progress_bar, train.loss_viewer
Inputs: top.lang, train.output_path
Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
If do_train is False:
Inputs: eval.output_path
Outputs: eval.output_box, eval.progress_bar, None
Inputs: top.lang, eval.output_path
Outputs: eval.output_box, eval.progress_bar, None, None
"""
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
running_info = {}
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
@ -125,9 +127,19 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
if os.path.isfile(swanlab_config_path):
with open(swanlab_config_path, encoding="utf-8") as f:
swanlab_public_config = json.load(f)
swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
if swanlab_link is not None:
running_info["swanlab_link"] = gr.Markdown(
ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True
)
return running_log, running_progress, running_info
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":

View File

@ -2814,4 +2814,11 @@ ALERTS = {
"ko": "모델이 내보내졌습니다.",
"ja": "モデルのエクスポートが完了しました。",
},
"info_swanlab_link": {
"en": "### SwanLab Link\n",
"ru": "### SwanLab ссылка\n",
"zh": "### SwanLab 链接\n",
"ko": "### SwanLab 링크\n",
"ja": "### SwanLab リンク\n",
},
}

View File

@ -423,6 +423,7 @@ class Runner:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None
running_log = ""
while self.trainer is not None:
@ -432,16 +433,18 @@ class Runner:
progress_bar: gr.Slider(visible=False),
}
else:
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
running_log, running_progress, running_info = get_trainer_info(lang, output_path, self.do_train)
return_dict = {
output_box: running_log,
progress_bar: running_progress,
}
if running_loss is not None:
return_dict[loss_viewer] = running_loss
if "loss_viewer" in running_info:
return_dict[loss_viewer] = running_info["loss_viewer"]
if "swanlab_link" in running_info:
return_dict[swanlab_link] = running_info["swanlab_link"]
yield return_dict
try:
self.trainer.wait(2)
self.trainer = None