Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-27 09:10:35 +08:00)
[test] add allreduce test on npu (#9619)
Co-authored-by: frozenleaves <frozen@Mac.local>
@@ -103,6 +103,36 @@ def is_torch_xpu_available():
     return get_current_accelerator().type == DeviceType.XPU


+def get_current_device() -> "torch.device":
+    r"""Get the current available device."""
+    if is_torch_xpu_available():
+        device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
+    elif is_torch_npu_available():
+        device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
+    elif is_torch_mps_available():
+        device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
+    elif is_torch_cuda_available():
+        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
+    else:
+        device = "cpu"
+
+    return torch.device(device)
+
+
+def get_device_count() -> int:
+    r"""Get the number of available devices."""
+    if is_torch_xpu_available():
+        return torch.xpu.device_count()
+    elif is_torch_npu_available():
+        return torch.npu.device_count()
+    elif is_torch_mps_available():
+        return torch.mps.device_count()
+    elif is_torch_cuda_available():
+        return torch.cuda.device_count()
+    else:
+        return 0
+
+
 def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
     """Gathers the tensor from all ranks and concats them along the first dim."""
     world_size = get_world_size()
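Both new helpers resolve the rank-local device from the LOCAL_RANK environment variable, which launchers such as torchrun export per process. A minimal usage sketch (not part of the diff; assumes the functions above are in scope):

    import os

    os.environ.setdefault("LOCAL_RANK", "0")  # normally exported by torchrun per process

    device = get_current_device()  # e.g. torch.device("npu:0") on an Ascend host
    count = get_device_count()     # e.g. 8 on an 8-card node, 0 if only CPU is visible
    print(device, count)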
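The hunk's trailing context cuts off inside all_gather, right after world_size = get_world_size(). A typical way such a helper is completed with torch.distributed — an illustrative assumption, not necessarily this repo's exact body — looks like:

    from typing import Optional

    import torch
    import torch.distributed as dist

    def all_gather_sketch(tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
        """Gather `tensor` from every rank and concat the results along dim 0."""
        world_size = dist.get_world_size(group=group)
        if world_size == 1:
            return tensor  # nothing to gather in a single-process run
        # all_gather fills a pre-allocated list with one buffer per rank
        buffers = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(buffers, tensor, group=group)
        return torch.cat(buffers, dim=0)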
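The commit title says it adds an allreduce test on NPU (#9619), but the test file is not shown in this hunk. A hypothetical sketch of what such a test might look like, assuming the torch_npu package and an hccl process group launched via torchrun (function name, tensor shape, and values are illustrative):

    import os

    import torch
    import torch_npu  # noqa: F401 — registers the "npu" device type (Ascend-only dependency)
    import torch.distributed as dist

    def test_allreduce_npu():
        # torchrun exports RANK/WORLD_SIZE/LOCAL_RANK; "hccl" is the Ascend collective backend
        dist.init_process_group(backend="hccl")
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        torch.npu.set_device(local_rank)
        rank, world_size = dist.get_rank(), dist.get_world_size()

        x = torch.ones(4, device=f"npu:{local_rank}") * (rank + 1)
        dist.all_reduce(x, op=dist.ReduceOp.SUM)

        # each rank contributes (rank + 1), so the sum is world_size * (world_size + 1) / 2
        assert torch.all(x == world_size * (world_size + 1) / 2)
        dist.destroy_process_group()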