diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 69c4ef77..483e9f4a 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -375,6 +375,14 @@ class SwanLabArguments:
         default=None,
         metadata={"help": "The log directory for SwanLab."},
     )
+    swanlab_lark_webhook_url: Optional[str] = field(
+        default=None,
+        metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
+    )
+    swanlab_lark_secret: Optional[str] = field(
+        default=None,
+        metadata={"help": "The Lark(飞书) secret for SwanLab."},
+    )
 
 
 @dataclass
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 2380540e..e9e58d7e 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -599,6 +599,15 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)
 
+    if finetuning_args.swanlab_lark_webhook_url is not None:
+        from swanlab.plugin.notification import LarkCallback  # type: ignore
+
+        lark_callback = LarkCallback(
+            webhook_url=finetuning_args.swanlab_lark_webhook_url,
+            secret=finetuning_args.swanlab_lark_secret,
+        )
+        swanlab.register_callbacks([lark_callback])
+
     class SwanLabCallbackExtension(SwanLabCallback):
         def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
            if not state.is_world_process_zero:
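For reference, here is a minimal standalone sketch of the code path the patch enables when `swanlab_lark_webhook_url` is set. The import path, the `LarkCallback(webhook_url=..., secret=...)` arguments, and the `swanlab.register_callbacks()` call are taken directly from the diff above; the webhook URL and secret values are placeholders, not real credentials.

```python
# Sketch of the Lark notification wiring added in get_swanlab_callback(),
# assuming swanlab and its notification plugin are installed. The URL and
# secret below are placeholders; in LLaMA-Factory they come from the two
# new SwanLabArguments fields rather than being hard-coded.
import swanlab
from swanlab.plugin.notification import LarkCallback

lark_callback = LarkCallback(
    webhook_url="https://open.feishu.cn/open-apis/bot/v2/hook/<token>",  # placeholder
    secret="<signing-secret>",  # optional; the patch passes None when unset
)
swanlab.register_callbacks([lark_callback])  # fire Lark notifications on run events
```

Because `swanlab_lark_secret` defaults to `None`, the patch works with Lark bots that do not use signature verification; only the webhook URL gates whether the callback is registered at all.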