mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
[v1] add dp & mp mesh (#9611)
This commit is contained in:
@@ -60,16 +60,16 @@ def get_rank() -> int:
|
|||||||
return int(os.getenv("RANK", "0"))
|
return int(os.getenv("RANK", "0"))
|
||||||
|
|
||||||
|
|
||||||
def get_local_rank() -> int:
|
|
||||||
"""Get local rank."""
|
|
||||||
return int(os.getenv("LOCAL_RANK", "0"))
|
|
||||||
|
|
||||||
|
|
||||||
def get_world_size() -> int:
|
def get_world_size() -> int:
|
||||||
"""Get world size."""
|
"""Get world size."""
|
||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_rank() -> int:
|
||||||
|
"""Get local rank."""
|
||||||
|
return int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
|
||||||
|
|
||||||
def get_local_world_size() -> int:
|
def get_local_world_size() -> int:
|
||||||
"""Get local world size."""
|
"""Get local world size."""
|
||||||
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
|
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
|
||||||
@@ -79,7 +79,7 @@ def get_local_world_size() -> int:
|
|||||||
def get_current_accelerator(check_available: bool = True) -> torch.device:
|
def get_current_accelerator(check_available: bool = True) -> torch.device:
|
||||||
"""Get current accelerator.
|
"""Get current accelerator.
|
||||||
|
|
||||||
Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
|
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
|
||||||
"""
|
"""
|
||||||
if not hasattr(torch, "accelerator"):
|
if not hasattr(torch, "accelerator"):
|
||||||
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
|
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
|
||||||
@@ -123,7 +123,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
|
|||||||
is_tensor = isinstance(data, torch.Tensor)
|
is_tensor = isinstance(data, torch.Tensor)
|
||||||
|
|
||||||
if is_ndarray:
|
if is_ndarray:
|
||||||
data = torch.from_numpy(data)
|
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
|
||||||
elif not is_tensor:
|
elif not is_tensor:
|
||||||
data = torch.tensor(data, dtype=torch.float, device=device)
|
data = torch.tensor(data, dtype=torch.float, device=device)
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
|
|||||||
if is_tensor:
|
if is_tensor:
|
||||||
return data
|
return data
|
||||||
elif is_ndarray:
|
elif is_ndarray:
|
||||||
return data.numpy()
|
return data.cpu().numpy()
|
||||||
elif data.numel() == 1:
|
elif data.numel() == 1:
|
||||||
return data.item()
|
return data.item()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the Bytedance's VeOmni library.
|
||||||
|
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/distributed/parallel_state.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -13,41 +16,91 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||||
|
|
||||||
from ..utils.types import TensorLike
|
from ..utils.types import Tensor, TensorLike
|
||||||
from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed
|
from .helper import (
|
||||||
|
ReduceOp,
|
||||||
|
all_gather,
|
||||||
|
all_reduce,
|
||||||
|
get_current_accelerator,
|
||||||
|
get_local_rank,
|
||||||
|
get_local_world_size,
|
||||||
|
get_rank,
|
||||||
|
get_world_size,
|
||||||
|
is_distributed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
|
||||||
|
class Dim(str, Enum):
|
||||||
|
"""Dimension names."""
|
||||||
|
|
||||||
|
MP_REPLICATE = "mp_replicate"
|
||||||
|
MP_SHARD = "mp_shard"
|
||||||
|
DP = "dp"
|
||||||
|
CP = "cp"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DistributedStrategy:
|
class DistributedStrategy:
|
||||||
"""Distributed strategy."""
|
"""Distributed strategy."""
|
||||||
|
|
||||||
|
mp_replicate_size: int = 1
|
||||||
|
"""Model parallel replicate size, default to 1."""
|
||||||
|
mp_shard_size: Optional[int] = None
|
||||||
|
"""Model parallel shard size, default to world_size // mp_replicate_size."""
|
||||||
dp_size: Optional[int] = None
|
dp_size: Optional[int] = None
|
||||||
tp_size: int = 1
|
"""Data parallel size, default to world_size // cp_size."""
|
||||||
|
cp_size: int = 1
|
||||||
|
"""Context parallel size, default to 1."""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
if not is_distributed():
|
||||||
|
self.mp_shard_size = 1
|
||||||
|
elif self.mp_shard_size is None:
|
||||||
|
self.mp_shard_size = get_world_size() // self.mp_replicate_size
|
||||||
|
elif self.mp_replicate_size * self.mp_shard_size != get_world_size():
|
||||||
|
raise ValueError(
|
||||||
|
f"mp_replicate_size * mp_shard_size must equal to world_size, "
|
||||||
|
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}."
|
||||||
|
)
|
||||||
|
|
||||||
if not is_distributed():
|
if not is_distributed():
|
||||||
self.dp_size = 1
|
self.dp_size = 1
|
||||||
elif self.dp_size is None:
|
elif self.dp_size is None:
|
||||||
self.dp_size = get_world_size() // self.tp_size
|
self.dp_size = get_world_size() // self.cp_size
|
||||||
elif self.dp_size * self.tp_size != get_world_size():
|
elif self.dp_size * self.cp_size != get_world_size():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"dp_size * tp_size must equal to world_size, "
|
f"dp_size * cp_size must equal to world_size, "
|
||||||
f"got {self.dp_size} * {self.tp_size} != {get_world_size()}."
|
f"got {self.dp_size} * {self.cp_size} != {get_world_size()}."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mesh_shape(self) -> tuple[int, int]:
|
def model_mesh_shape(self) -> tuple[int, int]:
|
||||||
"""Mesh shape."""
|
"""Model parallel mesh shape."""
|
||||||
return (self.dp_size, self.tp_size)
|
return (self.mp_replicate_size, self.mp_shard_size)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mesh_dim_names(self) -> tuple[str, str]:
|
def model_mesh_dim_names(self) -> tuple[str, str]:
|
||||||
"""Mesh dimension names."""
|
"""Model parallel mesh dimension names."""
|
||||||
return ("dp", "tp")
|
return (Dim.MP_REPLICATE.value, Dim.MP_SHARD.value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data_mesh_shape(self) -> tuple[int, int]:
|
||||||
|
"""Data parallel mesh shape."""
|
||||||
|
return (self.dp_size, self.cp_size)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data_mesh_dim_names(self) -> tuple[str, str]:
|
||||||
|
"""Data parallel mesh dimension names."""
|
||||||
|
return (Dim.DP.value, Dim.CP.value)
|
||||||
|
|
||||||
|
|
||||||
class DistributedInterface:
|
class DistributedInterface:
|
||||||
@@ -55,15 +108,18 @@ class DistributedInterface:
|
|||||||
|
|
||||||
_instance: Optional["DistributedInterface"] = None
|
_instance: Optional["DistributedInterface"] = None
|
||||||
_initialized: bool = False
|
_initialized: bool = False
|
||||||
|
_is_distributed = is_distributed()
|
||||||
|
_rank = get_rank()
|
||||||
|
_world_size = get_world_size()
|
||||||
|
_local_rank = get_local_rank()
|
||||||
|
_local_world_size = get_local_world_size()
|
||||||
|
|
||||||
is_distributed = is_distributed()
|
strategy: Optional[DistributedStrategy] = None
|
||||||
"""Check if distributed environment is available."""
|
"""Distributed strategy."""
|
||||||
rank = get_rank()
|
model_device_mesh: Optional[DeviceMesh] = None
|
||||||
"""Global rank."""
|
"""Model parallel device mesh."""
|
||||||
world_size = get_world_size()
|
data_device_mesh: Optional[DeviceMesh] = None
|
||||||
"""Global world size."""
|
"""Data parallel device mesh."""
|
||||||
device_mesh: Optional[DeviceMesh] = None
|
|
||||||
"""Device mesh."""
|
|
||||||
current_accelerator = get_current_accelerator()
|
current_accelerator = get_current_accelerator()
|
||||||
"""Current accelerator."""
|
"""Current accelerator."""
|
||||||
|
|
||||||
@@ -79,44 +135,89 @@ class DistributedInterface:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.strategy = strategy
|
self.strategy = strategy
|
||||||
if self.is_distributed:
|
if self._is_distributed:
|
||||||
self.device_mesh = init_device_mesh(
|
self.model_device_mesh = init_device_mesh(
|
||||||
device_type=self.current_accelerator.type,
|
device_type=self.current_accelerator.type,
|
||||||
mesh_shape=strategy.mesh_shape,
|
mesh_shape=strategy.model_mesh_shape,
|
||||||
mesh_dim_names=strategy.mesh_dim_names,
|
mesh_dim_names=strategy.model_mesh_dim_names,
|
||||||
|
)
|
||||||
|
self.data_device_mesh = init_device_mesh(
|
||||||
|
device_type=self.current_accelerator.type,
|
||||||
|
mesh_shape=strategy.data_mesh_shape,
|
||||||
|
mesh_dim_names=strategy.data_mesh_dim_names,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.device_mesh = None
|
self.model_device_mesh = None
|
||||||
|
self.data_device_mesh = None
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, "
|
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
|
||||||
f"rank={self.rank}, world_size={self.world_size}, "
|
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
|
||||||
f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}"
|
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def dp_rank(self) -> int:
|
@classmethod
|
||||||
"""Data parallel rank."""
|
def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
|
||||||
if self.device_mesh is None:
|
"""Get device mesh for specified dimension."""
|
||||||
|
if dim is None:
|
||||||
|
raise ValueError("dim must be specified.")
|
||||||
|
elif cls.model_device_mesh is None:
|
||||||
|
return None
|
||||||
|
elif dim in cls.strategy.data_mesh_dim_names:
|
||||||
|
return cls.data_device_mesh[dim.value]
|
||||||
|
else:
|
||||||
|
return cls.model_device_mesh[dim.value]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]:
|
||||||
|
"""Get process group for specified dimension."""
|
||||||
|
if cls.model_device_mesh is None or dim is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return cls.get_device_mesh(dim).get_group()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_rank(cls, dim: Optional[Dim] = None) -> int:
|
||||||
|
"""Get parallel rank for specified dimension."""
|
||||||
|
if cls.model_device_mesh is None:
|
||||||
return 0
|
return 0
|
||||||
|
elif dim is None:
|
||||||
|
return cls._rank
|
||||||
|
else:
|
||||||
|
return cls.get_device_mesh(dim).get_local_rank()
|
||||||
|
|
||||||
return self.device_mesh["dp"].get_rank()
|
@classmethod
|
||||||
|
def get_world_size(cls, dim: Optional[Dim] = None) -> int:
|
||||||
def dp_size(self) -> int:
|
"""Get parallel size for specified dimension."""
|
||||||
"""Data parallel size."""
|
if cls.model_device_mesh is None:
|
||||||
if self.device_mesh is None:
|
|
||||||
return 1
|
return 1
|
||||||
|
elif dim is None:
|
||||||
|
return cls._world_size
|
||||||
|
else:
|
||||||
|
return cls.get_device_mesh(dim).size()
|
||||||
|
|
||||||
return self.device_mesh["dp"].size()
|
@classmethod
|
||||||
|
def get_local_rank(cls) -> int:
|
||||||
|
"""Get parallel local rank."""
|
||||||
|
return cls._local_rank
|
||||||
|
|
||||||
def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike:
|
@classmethod
|
||||||
"""All reduce tensor."""
|
def get_local_world_size(cls) -> int:
|
||||||
if self.device_mesh is None:
|
"""Get parallel local world size."""
|
||||||
return data
|
return cls._local_world_size
|
||||||
|
|
||||||
return all_reduce(data, op, self.device_mesh["dp"].get_group())
|
@classmethod
|
||||||
|
def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
|
||||||
|
"""Gather tensor across specified parallel group."""
|
||||||
|
return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike:
|
||||||
|
"""Reduce tensor across specified parallel group."""
|
||||||
|
return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -20,5 +20,7 @@ from llamafactory.v1.accelerator.interface import DistributedInterface, Distribu
|
|||||||
|
|
||||||
def test_distributed_interface():
|
def test_distributed_interface():
|
||||||
DistributedInterface(DistributedStrategy())
|
DistributedInterface(DistributedStrategy())
|
||||||
assert DistributedInterface.rank == int(os.getenv("RANK", "0"))
|
assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0"))
|
||||||
assert DistributedInterface.world_size == int(os.getenv("WORLD_SIZE", "1"))
|
assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
assert DistributedInterface.get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
|
||||||
|
|||||||
Reference in New Issue
Block a user