[3rdparty] support swanlab lark notification (#7481)

This commit is contained in:
Xu-pixel 2025-03-27 01:52:01 +08:00 committed by GitHub
parent 01166841cf
commit f547334604
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 0 deletions

View File

@ -375,6 +375,14 @@ class SwanLabArguments:
default=None,
metadata={"help": "The log directory for SwanLab."},
)
swanlab_lark_webhook_url: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
)
swanlab_lark_secret: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) secret for SwanLab."},
)
@dataclass

View File

@ -599,6 +599,15 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
if finetuning_args.swanlab_lark_webhook_url is not None:
from swanlab.plugin.notification import LarkCallback # type: ignore
lark_callback = LarkCallback(
webhook_url=finetuning_args.swanlab_lark_webhook_url,
secret=finetuning_args.swanlab_lark_secret,
)
swanlab.register_callbacks([lark_callback])
class SwanLabCallbackExtension(SwanLabCallback):
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
if not state.is_world_process_zero: