mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
[webui] display swanlab exp link (#7089)
* webui add swanlab link
* change callback name
* update

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 891c4875039e8e3b7d0de025ee61c4ff003ff0c4
This commit is contained in:
parent
e86cb8a4fa
commit
210cdb9557
@ -87,6 +87,8 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
|||||||
|
|
||||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||||
|
|
||||||
|
SWANLAB_CONFIG = "swanlab_public_config.json"
|
||||||
|
|
||||||
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
|
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
|
||||||
|
|
||||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||||
|
@ -17,6 +17,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
@ -31,7 +33,7 @@ from transformers.trainer_pt_utils import get_parameter_names
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import IGNORE_INDEX
|
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
|
||||||
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
|
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
|
||||||
from ..hparams import FinetuningArguments, ModelArguments
|
from ..hparams import FinetuningArguments, ModelArguments
|
||||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
||||||
@ -51,7 +53,7 @@ if is_ray_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel, TrainerCallback
|
from transformers import PreTrainedModel, TrainerCallback, TrainerState
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
from ..hparams import DataArguments, RayArguments, TrainingArguments
|
from ..hparams import DataArguments, RayArguments, TrainingArguments
|
||||||
@ -592,7 +594,17 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
|
|||||||
if finetuning_args.swanlab_api_key is not None:
|
if finetuning_args.swanlab_api_key is not None:
|
||||||
swanlab.login(api_key=finetuning_args.swanlab_api_key)
|
swanlab.login(api_key=finetuning_args.swanlab_api_key)
|
||||||
|
|
||||||
swanlab_callback = SwanLabCallback(
|
class SwanLabCallbackExtension(SwanLabCallback):
    r"""
    SwanLab callback that also saves the run's public config to the output directory.

    After the parent callback has been set up, the public config of the current
    run is dumped as JSON to ``<args.output_dir>/<SWANLAB_CONFIG>``.
    ``get_trainer_info`` reads this file to obtain the experiment URL.
    """

    def setup(
        self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs
    ) -> None:
        r"""
        Sets up the SwanLab run and writes its public config to disk.

        Does nothing when ``state.is_world_process_zero`` is false, so a single
        process performs the setup and writes the file.
        """
        if not state.is_world_process_zero:
            return

        super().setup(args, state, model, **kwargs)
        # NOTE(review): `_experiment` is a private attribute of the parent callback and
        # may change between swanlab releases; confirm against the installed version.
        swanlab_public_config = self._experiment.get_run().public.json()

        # The output directory is not guaranteed to exist when this hook runs.
        os.makedirs(args.output_dir, exist_ok=True)
        # Write as UTF-8 explicitly: the reader opens this file with encoding="utf-8".
        with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w", encoding="utf-8") as f:
            json.dump(swanlab_public_config, f, indent=2)
|
||||||
|
|
||||||
|
swanlab_callback = SwanLabCallbackExtension(
|
||||||
project=finetuning_args.swanlab_project,
|
project=finetuning_args.swanlab_project,
|
||||||
workspace=finetuning_args.swanlab_workspace,
|
workspace=finetuning_args.swanlab_workspace,
|
||||||
experiment_name=finetuning_args.swanlab_run_name,
|
experiment_name=finetuning_args.swanlab_run_name,
|
||||||
|
@ -299,9 +299,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
swanlab_workspace = gr.Textbox()
|
swanlab_workspace = gr.Textbox()
|
||||||
swanlab_api_key = gr.Textbox()
|
swanlab_api_key = gr.Textbox()
|
||||||
swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
|
swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
|
||||||
|
swanlab_link = gr.Markdown(visible=False, container=True)
|
||||||
|
|
||||||
input_elems.update(
|
input_elems.update(
|
||||||
{use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
|
{
|
||||||
|
use_swanlab,
|
||||||
|
swanlab_project,
|
||||||
|
swanlab_run_name,
|
||||||
|
swanlab_workspace,
|
||||||
|
swanlab_api_key,
|
||||||
|
swanlab_mode,
|
||||||
|
swanlab_link,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
elem_dict.update(
|
elem_dict.update(
|
||||||
dict(
|
dict(
|
||||||
@ -312,6 +321,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
swanlab_workspace=swanlab_workspace,
|
swanlab_workspace=swanlab_workspace,
|
||||||
swanlab_api_key=swanlab_api_key,
|
swanlab_api_key=swanlab_api_key,
|
||||||
swanlab_mode=swanlab_mode,
|
swanlab_mode=swanlab_mode,
|
||||||
|
swanlab_link=swanlab_link,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -364,7 +374,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
loss_viewer=loss_viewer,
|
loss_viewer=loss_viewer,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
output_elems = [output_box, progress_bar, loss_viewer]
|
output_elems = [output_box, progress_bar, loss_viewer, swanlab_link]
|
||||||
|
|
||||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||||
|
@ -23,6 +23,7 @@ from ..extras.constants import (
|
|||||||
PEFT_METHODS,
|
PEFT_METHODS,
|
||||||
RUNNING_LOG,
|
RUNNING_LOG,
|
||||||
STAGES_USE_PAIR_DATA,
|
STAGES_USE_PAIR_DATA,
|
||||||
|
SWANLAB_CONFIG,
|
||||||
TRAINER_LOG,
|
TRAINER_LOG,
|
||||||
TRAINING_STAGES,
|
TRAINING_STAGES,
|
||||||
)
|
)
|
||||||
@ -30,6 +31,7 @@ from ..extras.packages import is_gradio_available, is_matplotlib_available
|
|||||||
from ..extras.ploting import gen_loss_plot
|
from ..extras.ploting import gen_loss_plot
|
||||||
from ..model import QuantizationMethod
|
from ..model import QuantizationMethod
|
||||||
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
|
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
|
||||||
|
from .locales import ALERTS
|
||||||
|
|
||||||
|
|
||||||
if is_gradio_available():
|
if is_gradio_available():
|
||||||
@ -86,20 +88,20 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
|
|||||||
return get_model_path(model_name), get_template(model_name)
|
return get_model_path(model_name), get_template(model_name)
|
||||||
|
|
||||||
|
|
||||||
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
|
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
|
||||||
r"""
|
r"""
|
||||||
Gets training information for monitor.
|
Gets training information for monitor.
|
||||||
|
|
||||||
If do_train is True:
|
If do_train is True:
|
||||||
Inputs: train.output_path
|
Inputs: top.lang, train.output_path
|
||||||
Outputs: train.output_box, train.progress_bar, train.loss_viewer
|
Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
|
||||||
If do_train is False:
|
If do_train is False:
|
||||||
Inputs: eval.output_path
|
Inputs: top.lang, eval.output_path
|
||||||
Outputs: eval.output_box, eval.progress_bar, None
|
Outputs: eval.output_box, eval.progress_bar, None, None
|
||||||
"""
|
"""
|
||||||
running_log = ""
|
running_log = ""
|
||||||
running_progress = gr.Slider(visible=False)
|
running_progress = gr.Slider(visible=False)
|
||||||
running_loss = None
|
running_info = {}
|
||||||
|
|
||||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||||
if os.path.isfile(running_log_path):
|
if os.path.isfile(running_log_path):
|
||||||
@ -125,9 +127,19 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
|
|||||||
running_progress = gr.Slider(label=label, value=percentage, visible=True)
|
running_progress = gr.Slider(label=label, value=percentage, visible=True)
|
||||||
|
|
||||||
if do_train and is_matplotlib_available():
|
if do_train and is_matplotlib_available():
|
||||||
running_loss = gr.Plot(gen_loss_plot(trainer_log))
|
running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))
|
||||||
|
|
||||||
return running_log, running_progress, running_loss
|
swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
|
||||||
|
if os.path.isfile(swanlab_config_path):
|
||||||
|
with open(swanlab_config_path, encoding="utf-8") as f:
|
||||||
|
swanlab_public_config = json.load(f)
|
||||||
|
swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
|
||||||
|
if swanlab_link is not None:
|
||||||
|
running_info["swanlab_link"] = gr.Markdown(
|
||||||
|
ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return running_log, running_progress, running_info
|
||||||
|
|
||||||
|
|
||||||
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||||
|
@ -2814,4 +2814,11 @@ ALERTS = {
|
|||||||
"ko": "모델이 내보내졌습니다.",
|
"ko": "모델이 내보내졌습니다.",
|
||||||
"ja": "モデルのエクスポートが完了しました。",
|
"ja": "モデルのエクスポートが完了しました。",
|
||||||
},
|
},
|
||||||
|
"info_swanlab_link": {
|
||||||
|
"en": "### SwanLab Link\n",
|
||||||
|
"ru": "### SwanLab ссылка\n",
|
||||||
|
"zh": "### SwanLab 链接\n",
|
||||||
|
"ko": "### SwanLab 링크\n",
|
||||||
|
"ja": "### SwanLab リンク\n",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
@ -423,6 +423,7 @@ class Runner:
|
|||||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
|
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
|
||||||
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
|
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
|
||||||
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
|
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
|
||||||
|
swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None
|
||||||
|
|
||||||
running_log = ""
|
running_log = ""
|
||||||
while self.trainer is not None:
|
while self.trainer is not None:
|
||||||
@ -432,16 +433,18 @@ class Runner:
|
|||||||
progress_bar: gr.Slider(visible=False),
|
progress_bar: gr.Slider(visible=False),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
|
running_log, running_progress, running_info = get_trainer_info(lang, output_path, self.do_train)
|
||||||
return_dict = {
|
return_dict = {
|
||||||
output_box: running_log,
|
output_box: running_log,
|
||||||
progress_bar: running_progress,
|
progress_bar: running_progress,
|
||||||
}
|
}
|
||||||
if running_loss is not None:
|
if "loss_viewer" in running_info:
|
||||||
return_dict[loss_viewer] = running_loss
|
return_dict[loss_viewer] = running_info["loss_viewer"]
|
||||||
|
|
||||||
|
if "swanlab_link" in running_info:
|
||||||
|
return_dict[swanlab_link] = running_info["swanlab_link"]
|
||||||
|
|
||||||
yield return_dict
|
yield return_dict
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.trainer.wait(2)
|
self.trainer.wait(2)
|
||||||
self.trainer = None
|
self.trainer = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user