feat: swanlab params

This commit is contained in:
ZeYi Lin
2024-12-19 18:47:27 +08:00
parent 96f8f103e5
commit d5cf87990e
2 changed files with 37 additions and 5 deletions

View File

@@ -308,11 +308,31 @@ class BAdamArgument:
class SwanLabArguments: class SwanLabArguments:
use_swanlab: bool = field( use_swanlab: bool = field(
default=False, default=False,
metadata={"help": ""}, metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."},
) )
swanlab_name: str = field( swanlab_project: str = field(
default="", default="",
metadata={}, metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default="",
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_experiment_name: str = field(
default="",
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_description: str = field(
default="",
metadata={"help": "The experiment description in SwanLab."},
)
swanlab_mode: Literal["cloud", "local", "disabled"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default="",
metadata={"help": "The API key for SwanLab."},
) )

View File

@@ -463,6 +463,18 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
r""" r"""
Gets the callback for logging to SwanLab. Gets the callback for logging to SwanLab.
""" """
from swanlab.integration.huggingface import SwanLabCallback import swanlab
from swanlab.integration.transformers import SwanLabCallback
return SwanLabCallback() if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
swanlab_callback = SwanLabCallback(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_experiment_name,
description=finetuning_args.swanlab_description,
mode=finetuning_args.swanlab_mode,
)
return swanlab_callback