feat: swanlab params

Former-commit-id: d5cf87990e5bea920ecd1561def09fa17cf328b1
This commit is contained in:
ZeYi Lin 2024-12-19 18:47:27 +08:00
parent 1a48340680
commit cc5cde734b
2 changed files with 37 additions and 5 deletions

View File

@ -308,11 +308,31 @@ class BAdamArgument:
class SwanLabArguments:
use_swanlab: bool = field(
default=False,
metadata={"help": ""},
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tools)."},
)
swanlab_name: str = field(
swanlab_project: str = field(
default="",
metadata={},
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default="",
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_experiment_name: str = field(
default="",
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_description: str = field(
default="",
metadata={"help": "The experiment description in SwanLab."},
)
swanlab_mode: Literal["cloud", "local", "disabled"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default="",
metadata={"help": "The API key for SwanLab."},
)

View File

@ -463,6 +463,18 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
r"""
Gets the callback for logging to SwanLab.
"""
from swanlab.integration.huggingface import SwanLabCallback
import swanlab
from swanlab.integration.transformers import SwanLabCallback
return SwanLabCallback()
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
swanlab_callback = SwanLabCallback(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_experiment_name,
description=finetuning_args.swanlab_description,
mode=finetuning_args.swanlab_mode,
)
return swanlab_callback