diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 968067a1..d49e4e1d 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -164,6 +164,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_adam_mini:
         check_version("adam-mini", mandatory=True)
 
+    if finetuning_args.use_swanlab:
+        check_version("swanlab", mandatory=True)
+
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)
 
@@ -345,6 +348,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if finetuning_args.finetuning_type == "lora":
         # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
         training_args.label_names = training_args.label_names or ["labels"]
+
+    if "swanlab" in training_args.report_to and finetuning_args.use_swanlab:
+        training_args.report_to.remove("swanlab")
 
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
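
A minimal sketch of the second hunk's intent, under the assumption that LLaMA-Factory drives SwanLab through its own callback when finetuning_args.use_swanlab is enabled: stripping "swanlab" from training_args.report_to should keep the transformers Trainer from also attaching its built-in SwanLab reporter and logging every metric twice. The dedupe_swanlab helper below is hypothetical, introduced only to exercise that branch in isolation, not part of the patch.

# Hypothetical standalone mirror of the patched branch in get_train_args().
def dedupe_swanlab(report_to: list[str], use_swanlab: bool) -> list[str]:
    # When the framework's own SwanLab callback is active, drop the
    # built-in "swanlab" reporter so metrics are not logged twice.
    if "swanlab" in report_to and use_swanlab:
        report_to.remove("swanlab")
    return report_to

assert dedupe_swanlab(["wandb", "swanlab"], use_swanlab=True) == ["wandb"]
assert dedupe_swanlab(["swanlab"], use_swanlab=False) == ["swanlab"]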