diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index c5d926ac..3adb382b 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -38,6 +38,7 @@ from .trainer_utils import get_ray_trainer, get_swanlab_callback if is_ray_available(): + import ray from ray.train.huggingface.transformers import RayTrainReportCallback @@ -77,6 +78,9 @@ def _training_function(config: dict[str, Any]) -> None: else: raise ValueError(f"Unknown task: {finetuning_args.stage}.") + if is_ray_available() and ray.is_initialized(): + return # if ray is intialized it will destroy the process group on return + try: if dist.is_initialized(): dist.destroy_process_group()