diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 15ffae53..45494649 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -600,7 +600,15 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall return super().setup(args, state, model, **kwargs) - swanlab_public_config = self._experiment.get_run().public.json() + + try: + if hasattr(self, "_swanlab"): + swanlab_public_config = self._swanlab.get_run().public.json() + else: # swanlab <= 0.4.9 + swanlab_public_config = self._experiment.get_run().public.json() + except Exception as e: + swanlab_public_config = {} + with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f: f.write(json.dumps(swanlab_public_config, indent=2))