[feature] support using ray.remote to start distributed training. (#10109)

xvxuopop
2026-01-28 16:05:29 +08:00
committed by GitHub
parent 9640f79ae5
commit 762b480131
4 changed files with 221 additions and 80 deletions


@@ -20,7 +20,6 @@
 import json
 import os
 from collections.abc import Callable, Mapping
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
@@ -34,6 +33,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():

 if is_ray_available():
     import ray
-    from ray.train import RunConfig, ScalingConfig
-    from ray.train.torch import TorchTrainer
+    from ray.util.placement_group import PlacementGroup, placement_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments, RayArguments, TrainingArguments
+    from ..hparams import DataArguments, TrainingArguments


 logger = logging.get_logger(__name__)
@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     return swanlab_callback


-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: dict[str, Any],
-    ray_args: "RayArguments",
-) -> "TorchTrainer":
-    if not ray_args.use_ray:
-        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
-
-    if ray_args.ray_init_kwargs is not None:
-        ray.init(**ray_args.ray_init_kwargs)
-
-    if ray_args.ray_storage_filesystem is not None:
-        # this means we are using s3/gcs
-        storage_path = ray_args.ray_storage_path
-    else:
-        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
-
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.ray_num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            placement_strategy=ray_args.placement_strategy,
-            use_gpu=True,
-        ),
-        run_config=RunConfig(
-            name=ray_args.ray_run_name,
-            storage_filesystem=ray_args.ray_storage_filesystem,
-            storage_path=storage_path,
-        ),
-    )
-    return trainer
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+    r"""Get the Ray placement group for distributed training."""
+    bundle = {"CPU": 10}
+    device_name = get_device_name().upper()
+    if device_name != "CPU":
+        bundle[device_name] = 1
+
+    bundles = [bundle for _ in range(num_workers)]
+    pg = placement_group(bundles, strategy="PACK")
+    return pg, bundle
+
+
+def get_ray_remote_config_for_worker(
+    placement_group: "PlacementGroup",
+    bundle_idx: int,
+    rank: int,
+    world_size: int,
+    master_addr: str,
+    master_port: str,
+    env: Optional[dict[str, str]] = None,
+) -> dict[str, Any]:
+    r"""Get the remote config for a Ray worker."""
+    if env is None:  # avoid calling `.update()` on a None default
+        env = {}
+
+    env_vars = {
+        "RANK": str(rank),
+        "WORLD_SIZE": str(world_size),
+        "MASTER_ADDR": master_addr,
+        "MASTER_PORT": master_port,
+        "TORCHELASTIC_USE_AGENT_STORE": "False",
+    }
+    env.update(env_vars)
+    remote_config = {
+        "scheduling_strategy": PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            placement_group_bundle_index=bundle_idx,
+        ),
+        "runtime_env": {"env_vars": env},
+        "num_cpus": 10,
+    }
+    device_name = get_device_name()
+    if device_name == "gpu":
+        remote_config["num_gpus"] = 1
+    elif device_name == "npu":
+        remote_config["resources"] = {"NPU": 1}
+
+    return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+    r"""Get the IP address of the Ray head node."""
+    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: Optional[str] = None) -> list[int]:
+    r"""Sort the placement group bundles by node IP, placing bundles on `master_addr` first if given."""
+
+    @ray.remote
+    def _get_node_ip():
+        return ray.util.get_node_ip_address().strip("[]")
+
+    tasks = []
+    for bundle_idx in range(placement_group.bundle_count):
+        task = _get_node_ip.options(
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_idx,
+            ),
+        ).remote()
+        tasks.append(task)
+
+    bundle_ips = ray.get(tasks)
+    bundle_node_ip_list = list(enumerate(bundle_ips))
+    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+    if master_addr is not None:
+        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+        if preferred_indices:
+            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+            sorted_bundle_indices = preferred_indices + remaining
+
+    return sorted_bundle_indices
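
The helpers above replace the previous `ray.train.TorchTrainer` launch path with plain `ray.remote` tasks scheduled into a placement group. Below is a minimal sketch of how they could be wired together; the `train_entrypoint` body, the worker count, and the master port are illustrative assumptions (the commit's actual launcher lives in the other changed files, not shown in this hunk), and the helpers are assumed importable from the module patched above.

import ray
import torch.distributed as dist


@ray.remote
def train_entrypoint() -> None:
    # Hypothetical worker body. RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
    # are already set via the runtime env built by
    # get_ray_remote_config_for_worker, so torch.distributed can read them.
    dist.init_process_group(backend="nccl")  # assumes CUDA workers
    # ... run the training loop here ...
    dist.destroy_process_group()


def launch_workers(num_workers: int = 2, master_port: str = "29500") -> None:
    if not ray.is_initialized():
        ray.init()

    pg, _ = get_placement_group(num_workers)
    ray.get(pg.ready())  # block until every bundle is reserved

    # Put bundles on the head node first so that rank 0 (the master) runs there.
    master_addr = get_ray_head_node_ip()
    sorted_indices = sort_placement_group_by_node_ip(pg, master_addr=master_addr)

    tasks = []
    for rank, bundle_idx in enumerate(sorted_indices):
        remote_config = get_ray_remote_config_for_worker(
            placement_group=pg,
            bundle_idx=bundle_idx,
            rank=rank,
            world_size=num_workers,
            master_addr=master_addr,
            master_port=master_port,
        )
        tasks.append(train_entrypoint.options(**remote_config).remote())

    ray.get(tasks)  # wait for all workers to finish

Unlike the removed `TorchTrainer` path, nothing here depends on Ray Train's scaling or storage configuration: the process group bootstraps purely from environment variables, which is the mechanism visible in `get_ray_remote_config_for_worker`.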