From c67d2b9327a4a0ce5512e99fc66437a9f5d07f73 Mon Sep 17 00:00:00 2001 From: Ze-Yi LIN <58305964+Zeyi-Lin@users.noreply.github.com> Date: Thu, 6 Mar 2025 00:33:37 +0800 Subject: [PATCH] [trainer] fix swanlab callback (#7176) Former-commit-id: 8ad03258e16309158368384e2a0a707845536133 --- src/llamafactory/train/trainer_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 15ffae53..45494649 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -600,7 +600,15 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall return super().setup(args, state, model, **kwargs) - swanlab_public_config = self._experiment.get_run().public.json() + + try: + if hasattr(self, "_swanlab"): + swanlab_public_config = self._swanlab.get_run().public.json() + else: # swanlab <= 0.4.9 + swanlab_public_config = self._experiment.get_run().public.json() + except Exception as e: + swanlab_public_config = {} + with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f: f.write(json.dumps(swanlab_public_config, indent=2))