feat: swanlab params

Former-commit-id: d5cf87990e5bea920ecd1561def09fa17cf328b1
This commit is contained in:
ZeYi Lin 2024-12-19 18:47:27 +08:00
parent 1a48340680
commit cc5cde734b
2 changed files with 37 additions and 5 deletions

View File

@ -308,11 +308,31 @@ class BAdamArgument:
class SwanLabArguments: class SwanLabArguments:
use_swanlab: bool = field( use_swanlab: bool = field(
default=False, default=False,
metadata={"help": ""}, metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."},
) )
swanlab_name: str = field( swanlab_project: str = field(
default="", default="",
metadata={}, metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default="",
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_experiment_name: str = field(
default="",
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_description: str = field(
default="",
metadata={"help": "The experiment description in SwanLab."},
)
swanlab_mode: Literal["cloud", "local", "disabled"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default="",
metadata={"help": "The API key for SwanLab."},
) )

View File

@ -463,6 +463,18 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
r""" r"""
Gets the callback for logging to SwanLab. Gets the callback for logging to SwanLab.
""" """
from swanlab.integration.huggingface import SwanLabCallback import swanlab
from swanlab.integration.transformers import SwanLabCallback
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
return SwanLabCallback() swanlab_callback = SwanLabCallback(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_experiment_name,
description=finetuning_args.swanlab_description,
mode=finetuning_args.swanlab_mode,
)
return swanlab_callback