[v1] add accelerator (#9607)

This commit is contained in:
Yaowei Zheng
2025-12-12 19:22:06 +08:00
committed by GitHub
parent 4fd94141a4
commit 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -1,36 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from torch.distributed.device_mesh import DeviceMesh
class DeviceMeshManager:
"""Device mesh manager."""
_instance: Optional["DeviceMeshManager"] = None
_initialized: bool = False
def __new__(cls) -> "DeviceMeshManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self.device_mesh: Optional[DeviceMesh] = None
self._initialized = True

View File

@@ -1,4 +1,7 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +15,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
import torch.distributed as dist
from ..utils.types import Tensor, TensorLike
def get_current_accelerator(check_available: bool = True):
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
@unique
class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"
META = "meta"
MPS = "mps"
NPU = "npu"
XPU = "xpu"
@unique
class ReduceOp(str, Enum):
SUM = "sum"
MEAN = "mean"
MAX = "max"
MIN = "min"
def is_distributed() -> bool:
"""Check if distributed environment is available."""
return os.getenv("RANK") is not None
def get_rank() -> int:
"""Get rank."""
return int(os.getenv("RANK", "0"))
def get_local_rank() -> int:
"""Get local rank."""
return int(os.getenv("LOCAL_RANK", "0"))
def get_world_size() -> int:
"""Get world size."""
return int(os.getenv("WORLD_SIZE", "1"))
def get_local_world_size() -> int:
"""Get local world size."""
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@lru_cache
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator.
Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
@@ -27,26 +86,78 @@ def get_current_accelerator(check_available: bool = True):
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
if accelerator is None:
return torch.device("cpu")
return torch.device(DeviceType.CPU.value)
return accelerator
@lru_cache
def is_torch_npu_available():
return get_current_accelerator().type == "npu"
@lru_cache
def is_torch_cuda_available():
return get_current_accelerator().type == "cuda"
return get_current_accelerator().type == DeviceType.CUDA
@lru_cache
def is_torch_xpu_available():
return get_current_accelerator().type == "xpu"
@lru_cache
def is_torch_mps_available():
return get_current_accelerator().type == "mps"
return get_current_accelerator().type == DeviceType.MPS
def is_torch_npu_available():
return get_current_accelerator().type == DeviceType.NPU
def is_torch_xpu_available():
return get_current_accelerator().type == DeviceType.XPU
def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
"""Gathers the tensor from all ranks and concats them along the first dim."""
world_size = get_world_size()
device = get_current_accelerator()
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
return output_tensor.view(-1, *tensor.size()[1:])
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
"""Performs all reduce in the given process group."""
device = get_current_accelerator()
is_ndarray = isinstance(data, np.ndarray)
is_tensor = isinstance(data, torch.Tensor)
if is_ndarray:
data = torch.from_numpy(data)
elif not is_tensor:
data = torch.tensor(data, dtype=torch.float, device=device)
reduce_ops = {
ReduceOp.MEAN: dist.ReduceOp.SUM,
ReduceOp.SUM: dist.ReduceOp.SUM,
ReduceOp.MAX: dist.ReduceOp.MAX,
ReduceOp.MIN: dist.ReduceOp.MIN,
}
dist.all_reduce(data, op=reduce_ops[op], group=group)
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
data /= dist.get_world_size(group=group)
if is_tensor:
return data
elif is_ndarray:
return data.numpy()
elif data.numel() == 1:
return data.item()
else:
return data.tolist()
@contextmanager
def main_process_first(local_only: bool = True) -> None:
"""A context manager for torch distributed environment to do something on the main process firstly."""
if get_world_size() > 1:
is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
try:
if not is_main_process:
dist.barrier()
yield
finally:
if is_main_process:
dist.barrier()
else:
yield

View File

@@ -0,0 +1,123 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Optional
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import TensorLike
from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed
@dataclass
class DistributedStrategy:
"""Distributed strategy."""
dp_size: Optional[int] = None
tp_size: int = 1
def __post_init__(self) -> None:
if not is_distributed():
self.dp_size = 1
elif self.dp_size is None:
self.dp_size = get_world_size() // self.tp_size
elif self.dp_size * self.tp_size != get_world_size():
raise ValueError(
f"dp_size * tp_size must equal to world_size, "
f"got {self.dp_size} * {self.tp_size} != {get_world_size()}."
)
@property
def mesh_shape(self) -> tuple[int, int]:
"""Mesh shape."""
return (self.dp_size, self.tp_size)
@property
def mesh_dim_names(self) -> tuple[str, str]:
"""Mesh dimension names."""
return ("dp", "tp")
class DistributedInterface:
"""Distributed interface."""
_instance: Optional["DistributedInterface"] = None
_initialized: bool = False
is_distributed = is_distributed()
"""Check if distributed environment is available."""
rank = get_rank()
"""Global rank."""
world_size = get_world_size()
"""Global world size."""
device_mesh: Optional[DeviceMesh] = None
"""Device mesh."""
current_accelerator = get_current_accelerator()
"""Current accelerator."""
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
"""Singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, strategy: DistributedStrategy) -> None:
if self._initialized:
return
self.strategy = strategy
if self.is_distributed:
self.device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=strategy.mesh_shape,
mesh_dim_names=strategy.mesh_dim_names,
)
else:
self.device_mesh = None
self._initialized = True
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, "
f"rank={self.rank}, world_size={self.world_size}, "
f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}"
)
def dp_rank(self) -> int:
"""Data parallel rank."""
if self.device_mesh is None:
return 0
return self.device_mesh["dp"].get_rank()
def dp_size(self) -> int:
"""Data parallel size."""
if self.device_mesh is None:
return 1
return self.device_mesh["dp"].size()
def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike:
"""All reduce tensor."""
if self.device_mesh is None:
return data
return all_reduce(data, op, self.device_mesh["dp"].get_group())
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))