diff --git a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py
index a3108954..8a7d6869 100644
--- a/src/llamafactory/v1/accelerator/helper.py
+++ b/src/llamafactory/v1/accelerator/helper.py
@@ -60,16 +60,16 @@ def get_rank() -> int:
     return int(os.getenv("RANK", "0"))
 
 
-def get_local_rank() -> int:
-    """Get local rank."""
-    return int(os.getenv("LOCAL_RANK", "0"))
-
-
 def get_world_size() -> int:
     """Get world size."""
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
+def get_local_rank() -> int:
+    """Get local rank."""
+    return int(os.getenv("LOCAL_RANK", "0"))
+
+
 def get_local_world_size() -> int:
     """Get local world size."""
     return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@@ -79,7 +79,7 @@ def get_local_world_size() -> int:
 def get_current_accelerator(check_available: bool = True) -> torch.device:
     """Get current accelerator.
 
-    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
+    Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
     """
     if not hasattr(torch, "accelerator"):
         raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
@@ -123,7 +123,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
     is_tensor = isinstance(data, torch.Tensor)
 
     if is_ndarray:
-        data = torch.from_numpy(data)
+        data = torch.from_numpy(data).to(device=device, dtype=torch.float)
     elif not is_tensor:
         data = torch.tensor(data, dtype=torch.float, device=device)
 
@@ -140,7 +140,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
     if is_tensor:
         return data
     elif is_ndarray:
-        return data.numpy()
+        return data.cpu().numpy()
     elif data.numel() == 1:
         return data.item()
     else:
diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py
index de4306f0..47a878f2 100644
--- a/src/llamafactory/v1/accelerator/interface.py
+++ b/src/llamafactory/v1/accelerator/interface.py
@@ -1,4 +1,7 @@
-# Copyright 2025 the LlamaFactory team.
+# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
+#
+# This code is inspired by the Bytedance's VeOmni library.
+# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/distributed/parallel_state.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,41 +16,91 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Any, Optional
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Optional
 
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
-from ..utils.types import TensorLike
-from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed
+from ..utils.types import Tensor, TensorLike
+from .helper import (
+    ReduceOp,
+    all_gather,
+    all_reduce,
+    get_current_accelerator,
+    get_local_rank,
+    get_local_world_size,
+    get_rank,
+    get_world_size,
+    is_distributed,
+)
+
+
+if TYPE_CHECKING:
+    from torch.distributed import ProcessGroup
+
+
+class Dim(str, Enum):
+    """Dimension names."""
+
+    MP_REPLICATE = "mp_replicate"
+    MP_SHARD = "mp_shard"
+    DP = "dp"
+    CP = "cp"
 
 
 @dataclass
 class DistributedStrategy:
     """Distributed strategy."""
 
+    mp_replicate_size: int = 1
+    """Model parallel replicate size, default to 1."""
+    mp_shard_size: Optional[int] = None
+    """Model parallel shard size, default to world_size // mp_replicate_size."""
     dp_size: Optional[int] = None
-    tp_size: int = 1
+    """Data parallel size, default to world_size // cp_size."""
+    cp_size: int = 1
+    """Context parallel size, default to 1."""
 
     def __post_init__(self) -> None:
+        if not is_distributed():
+            self.mp_shard_size = 1
+        elif self.mp_shard_size is None:
+            self.mp_shard_size = get_world_size() // self.mp_replicate_size
+        elif self.mp_replicate_size * self.mp_shard_size != get_world_size():
+            raise ValueError(
+                f"mp_replicate_size * mp_shard_size must equal to world_size, "
+                f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}."
+            )
+
         if not is_distributed():
             self.dp_size = 1
         elif self.dp_size is None:
-            self.dp_size = get_world_size() // self.tp_size
-        elif self.dp_size * self.tp_size != get_world_size():
+            self.dp_size = get_world_size() // self.cp_size
+        elif self.dp_size * self.cp_size != get_world_size():
             raise ValueError(
-                f"dp_size * tp_size must equal to world_size, "
-                f"got {self.dp_size} * {self.tp_size} != {get_world_size()}."
+                f"dp_size * cp_size must equal to world_size, "
+                f"got {self.dp_size} * {self.cp_size} != {get_world_size()}."
             )
 
     @property
-    def mesh_shape(self) -> tuple[int, int]:
-        """Mesh shape."""
-        return (self.dp_size, self.tp_size)
+    def model_mesh_shape(self) -> tuple[int, int]:
+        """Model parallel mesh shape."""
+        return (self.mp_replicate_size, self.mp_shard_size)
 
     @property
-    def mesh_dim_names(self) -> tuple[str, str]:
-        """Mesh dimension names."""
-        return ("dp", "tp")
+    def model_mesh_dim_names(self) -> tuple[str, str]:
+        """Model parallel mesh dimension names."""
+        return (Dim.MP_REPLICATE.value, Dim.MP_SHARD.value)
+
+    @property
+    def data_mesh_shape(self) -> tuple[int, int]:
+        """Data parallel mesh shape."""
+        return (self.dp_size, self.cp_size)
+
+    @property
+    def data_mesh_dim_names(self) -> tuple[str, str]:
+        """Data parallel mesh dimension names."""
+        return (Dim.DP.value, Dim.CP.value)
 
 
 class DistributedInterface:
@@ -55,15 +108,18 @@ class DistributedInterface:
 
     _instance: Optional["DistributedInterface"] = None
     _initialized: bool = False
+    _is_distributed = is_distributed()
+    _rank = get_rank()
+    _world_size = get_world_size()
+    _local_rank = get_local_rank()
+    _local_world_size = get_local_world_size()
 
-    is_distributed = is_distributed()
-    """Check if distributed environment is available."""
-    rank = get_rank()
-    """Global rank."""
-    world_size = get_world_size()
-    """Global world size."""
-    device_mesh: Optional[DeviceMesh] = None
-    """Device mesh."""
+    strategy: Optional[DistributedStrategy] = None
+    """Distributed strategy."""
+    model_device_mesh: Optional[DeviceMesh] = None
+    """Model parallel device mesh."""
+    data_device_mesh: Optional[DeviceMesh] = None
+    """Data parallel device mesh."""
     current_accelerator = get_current_accelerator()
     """Current accelerator."""
 
@@ -79,44 +135,89 @@ class DistributedInterface:
             return
 
         self.strategy = strategy
-        if self.is_distributed:
-            self.device_mesh = init_device_mesh(
+        if self._is_distributed:
+            self.model_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
-                mesh_shape=strategy.mesh_shape,
-                mesh_dim_names=strategy.mesh_dim_names,
+                mesh_shape=strategy.model_mesh_shape,
+                mesh_dim_names=strategy.model_mesh_dim_names,
+            )
+            self.data_device_mesh = init_device_mesh(
+                device_type=self.current_accelerator.type,
+                mesh_shape=strategy.data_mesh_shape,
+                mesh_dim_names=strategy.data_mesh_dim_names,
             )
         else:
-            self.device_mesh = None
+            self.model_device_mesh = None
+            self.data_device_mesh = None
 
         self._initialized = True
 
     def __str__(self) -> str:
         return (
-            f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, "
-            f"rank={self.rank}, world_size={self.world_size}, "
-            f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}"
+            f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
+            f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
+            f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
        )
 
-    def dp_rank(self) -> int:
-        """Data parallel rank."""
-        if self.device_mesh is None:
+    @classmethod
+    def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
+        """Get device mesh for specified dimension."""
+        if dim is None:
+            raise ValueError("dim must be specified.")
+        elif cls.model_device_mesh is None:
+            return None
+        elif dim in cls.strategy.data_mesh_dim_names:
+            return cls.data_device_mesh[dim.value]
+        else:
+            return cls.model_device_mesh[dim.value]
+
+    @classmethod
+    def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]:
+        """Get process group for specified dimension."""
+        if cls.model_device_mesh is None or dim is None:
+            return None
+        else:
+            return cls.get_device_mesh(dim).get_group()
+
+    @classmethod
+    def get_rank(cls, dim: Optional[Dim] = None) -> int:
+        """Get parallel rank for specified dimension."""
+        if cls.model_device_mesh is None:
             return 0
+        elif dim is None:
+            return cls._rank
+        else:
+            return cls.get_device_mesh(dim).get_local_rank()
 
-        return self.device_mesh["dp"].get_rank()
-
-    def dp_size(self) -> int:
-        """Data parallel size."""
-        if self.device_mesh is None:
+    @classmethod
+    def get_world_size(cls, dim: Optional[Dim] = None) -> int:
+        """Get parallel size for specified dimension."""
+        if cls.model_device_mesh is None:
             return 1
+        elif dim is None:
+            return cls._world_size
+        else:
+            return cls.get_device_mesh(dim).size()
 
-        return self.device_mesh["dp"].size()
+    @classmethod
+    def get_local_rank(cls) -> int:
+        """Get parallel local rank."""
+        return cls._local_rank
 
-    def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike:
-        """All reduce tensor."""
-        if self.device_mesh is None:
-            return data
+    @classmethod
+    def get_local_world_size(cls) -> int:
+        """Get parallel local world size."""
+        return cls._local_world_size
 
-        return all_reduce(data, op, self.device_mesh["dp"].get_group())
+    @classmethod
+    def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
+        """Gather tensor across specified parallel group."""
+        return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data
+
+    @classmethod
+    def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike:
+        """Reduce tensor across specified parallel group."""
+        return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data
 
 
 if __name__ == "__main__":
diff --git a/tests_v1/accelerator/test_interface.py b/tests_v1/accelerator/test_interface.py
index ab39915b..2651ebf7 100644
--- a/tests_v1/accelerator/test_interface.py
+++ b/tests_v1/accelerator/test_interface.py
@@ -20,5 +20,7 @@ from llamafactory.v1.accelerator.interface import DistributedInterface, Distribu
 
 def test_distributed_interface():
     DistributedInterface(DistributedStrategy())
-    assert DistributedInterface.rank == int(os.getenv("RANK", "0"))
-    assert DistributedInterface.world_size == int(os.getenv("WORLD_SIZE", "1"))
+    assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0"))
+    assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
+    assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
+    assert DistributedInterface.get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))