mirror of https://github.com/hiyouga/LLaMA-Factory.git
[feature] support using ray.remote to start distributed training. (#10109)
@@ -20,7 +20,6 @@
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||

@@ -34,6 +33,7 @@ from typing_extensions import override

 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params

@@ -49,15 +49,15 @@ if is_apollo_available():

 if is_ray_available():
     import ray
-    from ray.train import RunConfig, ScalingConfig
-    from ray.train.torch import TorchTrainer
+    from ray.util.placement_group import PlacementGroup, placement_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments, RayArguments, TrainingArguments
+    from ..hparams import DataArguments, TrainingArguments


 logger = logging.get_logger(__name__)

@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     return swanlab_callback


-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: dict[str, Any],
-    ray_args: "RayArguments",
-) -> "TorchTrainer":
-    if not ray_args.use_ray:
-        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
-
-    if ray_args.ray_init_kwargs is not None:
-        ray.init(**ray_args.ray_init_kwargs)
-
-    if ray_args.ray_storage_filesystem is not None:
-        # this means we are using s3/gcs
-        storage_path = ray_args.ray_storage_path
-    else:
-        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
-
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.ray_num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            placement_strategy=ray_args.placement_strategy,
-            use_gpu=True,
-        ),
-        run_config=RunConfig(
-            name=ray_args.ray_run_name,
-            storage_filesystem=ray_args.ray_storage_filesystem,
-            storage_path=storage_path,
-        ),
-    )
-    return trainer
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+    r"""Get the Ray placement group for distributed training."""
+    bundle = {"CPU": 10}
+    device_name = get_device_name().upper()
+    if device_name != "CPU":
+        bundle[device_name] = 1
+
+    bundles = [bundle for _ in range(num_workers)]
+    pg = placement_group(bundles, strategy="PACK")
+    return pg, bundle
+
+
+def get_ray_remote_config_for_worker(
+    placement_group: "PlacementGroup",
+    bundle_idx: int,
+    rank: int,
+    world_size: int,
+    master_addr: str,
+    master_port: str,
+    env: dict[str, str] = None,
+) -> dict[str, Any]:
+    r"""Get the remote config for a Ray worker."""
+    env_vars = {
+        "RANK": str(rank),
+        "WORLD_SIZE": str(world_size),
+        "MASTER_ADDR": master_addr,
+        "MASTER_PORT": master_port,
+        "TORCHELASTIC_USE_AGENT_STORE": "False",
+    }
+    env.update(env_vars)
+
+    remote_config = {
+        "scheduling_strategy": PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            placement_group_bundle_index=bundle_idx,
+        ),
+        "runtime_env": {"env_vars": env},
+        "num_cpus": 10,
+    }
+
+    device_name = get_device_name()
+    if device_name == "gpu":
+        remote_config["num_gpus"] = 1
+    elif device_name == "npu":
+        remote_config["resources"] = {"NPU": 1}
+
+    return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+    r"""Get the IP address of the Ray head node."""
+    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
+    r"""Sort the placement group bundles by their node IP addresses."""
+
+    @ray.remote
+    def _get_node_ip():
+        return ray.util.get_node_ip_address().strip("[]")
+
+    tasks = []
+    for bundle_idx in range(placement_group.bundle_count):
+        task = _get_node_ip.options(
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_idx,
+            ),
+        ).remote()
+        tasks.append(task)
+
+    bundle_ips = ray.get(tasks)
+    bundle_node_ip_list = list(enumerate(bundle_ips))
+
+    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+
+    if master_addr is not None:
+        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+        if preferred_indices:
+            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+            sorted_bundle_indices = preferred_indices + remaining
+
+    return sorted_bundle_indices
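
Taken together, the new helpers replace the earlier Ray Train (`TorchTrainer`) wrapper with direct `ray.remote` scheduling: reserve a placement group, order its bundles so that rank 0 lands on the master node, build a per-worker remote config, and launch one task per rank. Below is a minimal sketch (not part of the commit) of how the helpers might be composed; `train_fn`, `world_size = 4`, and port `29500` are illustrative assumptions.

import ray

def train_fn(rank: int) -> None:
    # Placeholder training entry point; a real worker would set up
    # torch.distributed here (see the worker-side sketch below).
    print(f"worker {rank} started")

ray.init()
world_size = 4
pg, _bundle = get_placement_group(world_size)
ray.get(pg.ready())  # block until the cluster has reserved every bundle

master_addr = get_ray_head_node_ip()
master_port = "29500"  # assumed to be free on the head node
sorted_bundles = sort_placement_group_by_node_ip(pg, master_addr)

futures = []
for rank, bundle_idx in enumerate(sorted_bundles):
    # pass env={} explicitly: the helper updates the given dict in place
    remote_config = get_ray_remote_config_for_worker(
        pg, bundle_idx, rank, world_size, master_addr, master_port, env={}
    )
    worker = ray.remote(**remote_config)(train_fn)
    futures.append(worker.remote(rank))

ray.get(futures)  # wait for all ranks to finish

Putting the head-node bundles first in the sorted order means rank 0, which binds `MASTER_PORT`, runs on the same host as `MASTER_ADDR`, so the rendezvous address is always reachable from every worker.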
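
On the worker side, the environment variables injected through `runtime_env` are exactly the ones `torch.distributed` expects, so a worker entry point (again a sketch under assumptions, not code from this commit) can rely on the default `env://` rendezvous:

import torch
import torch.distributed as dist

def train_fn(rank: int) -> None:
    # RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT were injected via
    # runtime_env, so init_process_group reads them from the environment.
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    assert dist.get_rank() == rank
    # ... run the training loop ...
    dist.destroy_process_group()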