diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 5f927ddea..9e14c763e 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -50,6 +50,7 @@ if is_apollo_available(): if is_ray_available(): import ray + from ray.util.state import list_nodes from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -941,7 +942,7 @@ def get_ray_remote_config_for_worker( def get_ray_head_node_ip() -> str: r"""Get the IP address of the Ray head node.""" - head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False)) + head_ip = next(node["node_ip"] for node in list_nodes() if node.get("is_head_node", False)) return head_ip