[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yaowei Zheng
2025-12-25 02:11:04 +08:00
committed by GitHub
parent 6a2eafbae3
commit a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -94,9 +94,8 @@ def configure_quantization(
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if (
quant_method not in (QuantizationMethod.MXFP4 and QuantizationMethod.FP8)
and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
is_deepspeed_zero3_enabled() or is_fsdp_enabled()
):
# mxfp4 will dequant the model weights
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

View File

@@ -15,11 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used by the distributed interface.
Including:
- Environment info (rank, world_size, local_rank, etc.)
- Accelerator info (device type, device count, etc.)
- Collective communication operations (all_gather, all_reduce, broadcast)
- Synchronize processes and ensure main-process-first execution order
"""
import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import Optional
from functools import lru_cache, wraps
from typing import Callable, Optional
import numpy as np
import torch
@@ -46,6 +55,22 @@ class ReduceOp(str, Enum):
MIN = "min"
def requires_accelerator(fn):
"""Decorator to check if torch.accelerator is available.
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
@wraps(fn)
def wrapper(*args, **kwargs):
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
return fn(*args, **kwargs)
return wrapper
def is_distributed() -> bool:
"""Check if distributed environment is available."""
return os.getenv("RANK") is not None
@@ -72,105 +97,105 @@ def get_local_world_size() -> int:
@lru_cache
@requires_accelerator
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator.
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
"""Get current accelerator."""
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
if accelerator is None:
return torch.device(DeviceType.CPU.value)
return accelerator or torch.device(DeviceType.CPU.value)
return accelerator
@lru_cache
@requires_accelerator
def get_device_count() -> int:
"""Get the number of available devices."""
return torch.accelerator.device_count()
@requires_accelerator
def synchronize() -> None:
"""Synchronize all processes."""
torch.accelerator.synchronize()
@requires_accelerator
def set_device() -> None:
"""Set current accelerator."""
torch.accelerator.set_device_index(get_local_rank())
def is_torch_cuda_available():
"""Check if CUDA is available."""
return get_current_accelerator().type == DeviceType.CUDA
def is_torch_mps_available():
"""Check if MPS is available."""
return get_current_accelerator().type == DeviceType.MPS
def is_torch_npu_available():
"""Check if NPU is available."""
return get_current_accelerator().type == DeviceType.NPU
def is_torch_xpu_available():
"""Check if XPU is available."""
return get_current_accelerator().type == DeviceType.XPU
def get_current_device() -> "torch.device":
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_npu_available():
device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_mps_available():
device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_cuda_available():
device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
"""Operate tensorlike data on current accelerator."""
device = get_current_accelerator()
is_tensor = isinstance(data, torch.Tensor)
is_ndarray = isinstance(data, np.ndarray)
if is_tensor:
orig_device = data.device
data = data.to(device=device)
elif is_ndarray:
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
else:
device = "cpu"
data = torch.tensor(data, dtype=torch.float, device=device)
return torch.device(device)
result = fn(data, **kwargs)
def get_device_count() -> int:
r"""Get the number of available devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_mps_available():
return torch.mps.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
if is_tensor:
return result.to(orig_device)
elif is_ndarray:
return result.cpu().numpy()
elif result.numel() == 1:
return result.item()
else:
return 0
return result.tolist()
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and concats them along the first dim."""
"""Gathers the tensor from all ranks and stacks them at the first dim."""
world_size = get_world_size()
device = get_current_accelerator()
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
return output_tensor.view(-1, *tensor.size()[1:])
return output_tensor.view(-1, *tensor.size())
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
"""Performs all reduce in the given process group."""
device = get_current_accelerator()
is_ndarray = isinstance(data, np.ndarray)
is_tensor = isinstance(data, torch.Tensor)
if is_ndarray:
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
elif not is_tensor:
data = torch.tensor(data, dtype=torch.float, device=device)
reduce_ops = {
ReduceOp.MEAN: dist.ReduceOp.SUM,
ReduceOp.SUM: dist.ReduceOp.SUM,
ReduceOp.MAX: dist.ReduceOp.MAX,
ReduceOp.MIN: dist.ReduceOp.MIN,
}
dist.all_reduce(data, op=reduce_ops[op], group=group)
dist.all_reduce(tensor, op=reduce_ops[op], group=group)
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
data /= dist.get_world_size(group=group)
tensor /= dist.get_world_size(group=group)
if is_tensor:
return data
elif is_ndarray:
return data.cpu().numpy()
elif data.numel() == 1:
return data.item()
else:
return data.tolist()
return tensor
def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
"""Broadcasts the tensor from the src process to all other processes."""
dist.broadcast(tensor, src=src, group=group)
return tensor
@contextmanager

View File

@@ -15,26 +15,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A unified interface for model parallelism and data parallelism.
Supports model parallelism types:
- mp_replicate: Replicate model across multiple devices.
- mp_shard: Shard model across multiple devices.
And data parallelism types:
- dp: Data parallelism.
- cp: Context parallelism.
"""
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Any, Optional
from torch.distributed import init_process_group
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from .helper import (
ReduceOp,
all_gather,
all_reduce,
get_current_accelerator,
get_local_rank,
get_local_world_size,
get_rank,
get_world_size,
is_distributed,
)
from . import helper
class Dim(str, Enum):
@@ -60,24 +61,24 @@ class DistributedStrategy:
"""Context parallel size, default to 1."""
def __post_init__(self) -> None:
if not is_distributed():
if not helper.is_distributed():
self.mp_shard_size = 1
elif self.mp_shard_size is None:
self.mp_shard_size = get_world_size() // self.mp_replicate_size
elif self.mp_replicate_size * self.mp_shard_size != get_world_size():
self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
raise ValueError(
f"mp_replicate_size * mp_shard_size must equal to world_size, "
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}."
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
)
if not is_distributed():
if not helper.is_distributed():
self.dp_size = 1
elif self.dp_size is None:
self.dp_size = get_world_size() // self.cp_size
elif self.dp_size * self.cp_size != get_world_size():
self.dp_size = helper.get_world_size() // self.cp_size
elif self.dp_size * self.cp_size != helper.get_world_size():
raise ValueError(
f"dp_size * cp_size must equal to world_size, "
f"got {self.dp_size} * {self.cp_size} != {get_world_size()}."
f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
)
@property
@@ -106,20 +107,6 @@ class DistributedInterface:
_instance: Optional["DistributedInterface"] = None
_initialized: bool = False
_is_distributed = is_distributed()
_rank = get_rank()
_world_size = get_world_size()
_local_rank = get_local_rank()
_local_world_size = get_local_world_size()
strategy: Optional[DistributedStrategy] = None
"""Distributed strategy."""
model_device_mesh: Optional[DeviceMesh] = None
"""Model parallel device mesh."""
data_device_mesh: Optional[DeviceMesh] = None
"""Data parallel device mesh."""
current_accelerator = get_current_accelerator()
"""Current accelerator."""
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
"""Singleton pattern."""
@@ -132,6 +119,14 @@ class DistributedInterface:
if self._initialized:
return
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.device_count = helper.get_device_count()
if config is None:
self.strategy = DistributedStrategy()
timeout = 18000
@@ -145,6 +140,7 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
@@ -169,65 +165,84 @@ class DistributedInterface:
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
@classmethod
def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
elif cls.model_device_mesh is None:
elif self.model_device_mesh is None:
return None
elif dim in cls.strategy.data_mesh_dim_names:
return cls.data_device_mesh[dim.value]
elif dim in self.strategy.data_mesh_dim_names:
return self.data_device_mesh[dim.value]
else:
return cls.model_device_mesh[dim.value]
return self.model_device_mesh[dim.value]
@classmethod
def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension."""
if cls.model_device_mesh is None or dim is None:
if self.model_device_mesh is None or dim is None:
return None
else:
return cls.get_device_mesh(dim).get_group()
return self.get_device_mesh(dim).get_group()
@classmethod
def get_rank(cls, dim: Optional[Dim] = None) -> int:
def get_rank(self, dim: Optional[Dim] = None) -> int:
"""Get parallel rank for specified dimension."""
if cls.model_device_mesh is None:
if self.model_device_mesh is None:
return 0
elif dim is None:
return cls._rank
return self._rank
else:
return cls.get_device_mesh(dim).get_local_rank()
return self.get_device_mesh(dim).get_local_rank()
@classmethod
def get_world_size(cls, dim: Optional[Dim] = None) -> int:
def get_world_size(self, dim: Optional[Dim] = None) -> int:
"""Get parallel size for specified dimension."""
if cls.model_device_mesh is None:
if self.model_device_mesh is None:
return 1
elif dim is None:
return cls._world_size
return self._world_size
else:
return cls.get_device_mesh(dim).size()
return self.get_device_mesh(dim).size()
@classmethod
def get_local_rank(cls) -> int:
def get_local_rank(self) -> int:
"""Get parallel local rank."""
return cls._local_rank
return self._local_rank
@classmethod
def get_local_world_size(cls) -> int:
def get_local_world_size(self) -> int:
"""Get parallel local world size."""
return cls._local_world_size
return self._local_world_size
@classmethod
def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
"""Gather tensor across specified parallel group."""
return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
else:
return data
@classmethod
def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike:
def all_reduce(
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group."""
return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
else:
return data
def sync(self) -> None:
"""Synchronize all processes."""
helper.synchronize()
def barrier(self) -> None:
"""Barrier all processes."""
barrier()
def destroy(self) -> None:
"""Destroy all processes."""
destroy_process_group()
if __name__ == "__main__":

View File

@@ -97,7 +97,7 @@ class ModelLoader:
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface.current_accelerator,
device_map=DistributedInterface().current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)

View File

@@ -22,10 +22,10 @@ from typing import Optional
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ..utils.batching_queue import BaseBatchingQueue
from ..utils.logging import get_logger
from ..utils.types import Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
from ...utils.batching_queue import BaseBatchingQueue
from ...utils.logging import get_logger
from ...utils.types import Processor, TorchDataset
from .data_collator import DataCollator
logger = get_logger(__name__)

View File

@@ -17,7 +17,7 @@ import socket
def find_available_port() -> int:
r"""Find an available port on the local machine."""
"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
@@ -26,9 +26,5 @@ def find_available_port() -> int:
def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
if __name__ == "__main__":
print(find_available_port())
"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]

View File

@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
@contextmanager
def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595):
"""Set distributed environment variables."""
env_vars = {
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(master_port),
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"LOCAL_WORLD_SIZE": str(world_size),
}
os.environ.update(env_vars)
try:
yield
finally:
for key in env_vars.keys():
os.environ.pop(key, None)