mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-27 09:10:35 +08:00
[misc] fix accelerator (#9661)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -15,11 +15,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Utility functions used by the distributed interface.
+
+Including:
+- Environment info (rank, world_size, local_rank, etc.)
+- Accelerator info (device type, device count, etc.)
+- Collective communication operations (all_gather, all_reduce, broadcast)
+- Synchronize processes and ensure main-process-first execution order
+"""
+
 import os
 from contextlib import contextmanager
 from enum import Enum, unique
-from functools import lru_cache
-from typing import Optional
+from functools import lru_cache, wraps
+from typing import Callable, Optional
 
 import numpy as np
 import torch
@@ -46,6 +55,22 @@ class ReduceOp(str, Enum):
     MIN = "min"
 
 
+def requires_accelerator(fn):
+    """Decorator to check if torch.accelerator is available.
+
+    Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
+    """
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not hasattr(torch, "accelerator"):
+            raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
+
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
 def is_distributed() -> bool:
     """Check if distributed environment is available."""
     return os.getenv("RANK") is not None
@@ -72,105 +97,105 @@ def get_local_world_size() -> int:
 
 
 @lru_cache
+@requires_accelerator
 def get_current_accelerator(check_available: bool = True) -> torch.device:
-    """Get current accelerator.
-
-    Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
-    """
-    if not hasattr(torch, "accelerator"):
-        raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
-
+    """Get current accelerator."""
     accelerator = torch.accelerator.current_accelerator(check_available=check_available)
-    if accelerator is None:
-        return torch.device(DeviceType.CPU.value)
+    return accelerator or torch.device(DeviceType.CPU.value)
 
-    return accelerator
+
+@lru_cache
+@requires_accelerator
+def get_device_count() -> int:
+    """Get the number of available devices."""
+    return torch.accelerator.device_count()
 
 
+@requires_accelerator
 def synchronize() -> None:
     """Synchronize all processes."""
     torch.accelerator.synchronize()
 
 
+@requires_accelerator
 def set_device() -> None:
     """Set current accelerator."""
     torch.accelerator.set_device_index(get_local_rank())
 
 
+def is_torch_cuda_available():
+    """Check if CUDA is available."""
+    return get_current_accelerator().type == DeviceType.CUDA
+
+
+def is_torch_mps_available():
+    """Check if MPS is available."""
+    return get_current_accelerator().type == DeviceType.MPS
+
+
+def is_torch_npu_available():
+    """Check if NPU is available."""
+    return get_current_accelerator().type == DeviceType.NPU
+
+
+def is_torch_xpu_available():
+    """Check if XPU is available."""
+    return get_current_accelerator().type == DeviceType.XPU
+
+
-def get_current_device() -> "torch.device":
-    r"""Get the current available device."""
-    if is_torch_xpu_available():
-        device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
-    elif is_torch_npu_available():
-        device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
-    elif is_torch_mps_available():
-        device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
-    elif is_torch_cuda_available():
-        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
+def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
+    """Operate tensorlike data on current accelerator."""
+    device = get_current_accelerator()
+    is_tensor = isinstance(data, torch.Tensor)
+    is_ndarray = isinstance(data, np.ndarray)
+
+    if is_tensor:
+        orig_device = data.device
+        data = data.to(device=device)
+    elif is_ndarray:
+        data = torch.from_numpy(data).to(device=device, dtype=torch.float)
     else:
-        device = "cpu"
+        data = torch.tensor(data, dtype=torch.float, device=device)
 
-    return torch.device(device)
+    result = fn(data, **kwargs)
 
-
-def get_device_count() -> int:
-    r"""Get the number of available devices."""
-    if is_torch_xpu_available():
-        return torch.xpu.device_count()
-    elif is_torch_npu_available():
-        return torch.npu.device_count()
-    elif is_torch_mps_available():
-        return torch.mps.device_count()
-    elif is_torch_cuda_available():
-        return torch.cuda.device_count()
+    if is_tensor:
+        return result.to(orig_device)
+    elif is_ndarray:
+        return result.cpu().numpy()
+    elif result.numel() == 1:
+        return result.item()
     else:
-        return 0
+        return result.tolist()
 
 
 def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-    """Gathers the tensor from all ranks and concats them along the first dim."""
+    """Gathers the tensor from all ranks and stacks them at the first dim."""
     world_size = get_world_size()
-    device = get_current_accelerator()
-    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
+    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
     dist.all_gather_into_tensor(output_tensor, tensor, group=group)
-    return output_tensor.view(-1, *tensor.size()[1:])
+    return output_tensor.view(-1, *tensor.size())
 
 
-def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
+def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
     """Performs all reduce in the given process group."""
-    device = get_current_accelerator()
-    is_ndarray = isinstance(data, np.ndarray)
-    is_tensor = isinstance(data, torch.Tensor)
-
-    if is_ndarray:
-        data = torch.from_numpy(data).to(device=device, dtype=torch.float)
-    elif not is_tensor:
-        data = torch.tensor(data, dtype=torch.float, device=device)
-
     reduce_ops = {
         ReduceOp.MEAN: dist.ReduceOp.SUM,
         ReduceOp.SUM: dist.ReduceOp.SUM,
         ReduceOp.MAX: dist.ReduceOp.MAX,
         ReduceOp.MIN: dist.ReduceOp.MIN,
     }
-    dist.all_reduce(data, op=reduce_ops[op], group=group)
+    dist.all_reduce(tensor, op=reduce_ops[op], group=group)
     if op == ReduceOp.MEAN:  # ReduceOp.AVG is not supported by the NPU backend
-        data /= dist.get_world_size(group=group)
+        tensor /= dist.get_world_size(group=group)
 
-    if is_tensor:
-        return data
-    elif is_ndarray:
-        return data.cpu().numpy()
-    elif data.numel() == 1:
-        return data.item()
-    else:
-        return data.tolist()
+    return tensor
 
 
 def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
     """Broadcasts the tensor from the src process to all other processes."""
     dist.broadcast(tensor, src=src, group=group)
     return tensor
 
 
 @contextmanager
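A standalone sketch of the guard pattern introduced by requires_accelerator: the decorator refuses to run the wrapped function when torch.accelerator (torch >= 2.7.0) is missing, instead of failing later with an AttributeError inside the body. The decorator below is copied from the diff; count_devices is a hypothetical example helper, not part of this commit.

from functools import wraps

import torch


def requires_accelerator(fn):
    """Check that torch.accelerator exists before running fn (copied from the diff above)."""

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not hasattr(torch, "accelerator"):
            raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")

        return fn(*args, **kwargs)

    return wrapper


@requires_accelerator
def count_devices() -> int:
    """Hypothetical helper: 0 on a CPU-only build, otherwise the accelerator device count."""
    return torch.accelerator.device_count()


# On torch >= 2.7.0 this prints the device count; on older torch it raises a clear RuntimeError
# before the body ever touches the missing torch.accelerator namespace.
print(count_devices())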
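One behavioral change worth calling out: all_gather now returns a stacked result with a leading world-size dimension instead of a concatenation along the first dimension. A torch-only sketch of the reshape difference (no process group involved; the flat gathered buffer is simulated by repeating one rank's tensor):

import torch

world_size = 4
tensor = torch.arange(6.0).view(2, 3)            # stand-in for one rank's tensor, shape (2, 3)
flat = tensor.repeat(world_size, 1).reshape(-1)  # stand-in for the all_gather_into_tensor buffer

print(flat.view(-1, *tensor.size()[1:]).shape)   # old behavior: concat -> torch.Size([8, 3])
print(flat.view(-1, *tensor.size()).shape)       # new behavior: stack  -> torch.Size([4, 2, 3])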
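The TensorLike handling that used to live inside all_reduce has moved into operate_tensorlike, which moves the input to the current accelerator, runs the given collective, and converts the result back to the input's kind. A hypothetical call-site sketch, not taken from this commit: it assumes the module is importable as distributed (the real module path is not shown in this diff) and that torch.distributed is already initialized, e.g. under torchrun.

# Hypothetical usage, e.g. launched with: torchrun --nproc_per_node=2 demo.py
import numpy as np

from distributed import ReduceOp, all_reduce, operate_tensorlike  # hypothetical import path

local_loss = 0.53  # a plain Python float computed on this rank
# operate_tensorlike converts the float to a tensor on the current accelerator, calls
# all_reduce(tensor, op=ReduceOp.MEAN), and returns a float again (item() for 1-element results).
avg_loss = operate_tensorlike(all_reduce, local_loss, op=ReduceOp.MEAN)

per_class_hits = np.zeros(10, dtype=np.float32)  # ndarray metrics round-trip the same way
total_hits = operate_tensorlike(all_reduce, per_class_hits, op=ReduceOp.SUM)  # returns an ndarray

print(avg_loss, total_hits)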