[v1] add dp & mp mesh (#9611)

This commit is contained in:
Yaowei Zheng
2025-12-13 01:44:28 +08:00
committed by GitHub
parent 203069e11c
commit 110d21713e
3 changed files with 158 additions and 55 deletions

View File

@@ -60,16 +60,16 @@ def get_rank() -> int:
return int(os.getenv("RANK", "0"))
def get_local_rank() -> int:
"""Get local rank."""
return int(os.getenv("LOCAL_RANK", "0"))
def get_world_size() -> int:
"""Get world size."""
return int(os.getenv("WORLD_SIZE", "1"))
def get_local_rank() -> int:
"""Get local rank."""
return int(os.getenv("LOCAL_RANK", "0"))
def get_local_world_size() -> int:
"""Get local world size."""
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@@ -79,7 +79,7 @@ def get_local_world_size() -> int:
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator.
Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
@@ -123,7 +123,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
is_tensor = isinstance(data, torch.Tensor)
if is_ndarray:
data = torch.from_numpy(data)
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
elif not is_tensor:
data = torch.tensor(data, dtype=torch.float, device=device)
@@ -140,7 +140,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
if is_tensor:
return data
elif is_ndarray:
return data.numpy()
return data.cpu().numpy()
elif data.numel() == 1:
return data.item()
else: