feat: swanlab params

Former-commit-id: 761b3bdb03e27826fde2ca86d4e37b53c2bbc777
This commit is contained in:
ZeYi Lin 2024-12-19 18:47:27 +08:00
parent 7eeeffdb8a
commit 44dfbf9dbd
2 changed files with 37 additions and 5 deletions

View File

@ -308,11 +308,31 @@ class BAdamArgument:
class SwanLabArguments: class SwanLabArguments:
use_swanlab: bool = field( use_swanlab: bool = field(
default=False, default=False,
metadata={"help": ""}, metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."},
) )
swanlab_name: str = field( swanlab_project: str = field(
default="", default="",
metadata={}, metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default="",
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_experiment_name: str = field(
default="",
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_description: str = field(
default="",
metadata={"help": "The experiment description in SwanLab."},
)
swanlab_mode: Literal["cloud", "local", "disabled"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default="",
metadata={"help": "The API key for SwanLab."},
) )

View File

@ -463,6 +463,18 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
r""" r"""
Gets the callback for logging to SwanLab. Gets the callback for logging to SwanLab.
""" """
from swanlab.integration.huggingface import SwanLabCallback import swanlab
from swanlab.integration.transformers import SwanLabCallback
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
return SwanLabCallback() swanlab_callback = SwanLabCallback(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_experiment_name,
description=finetuning_args.swanlab_description,
mode=finetuning_args.swanlab_mode,
)
return swanlab_callback