Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-27 09:10:35 +08:00)
[v1] add dp & mp mesh (#9611)
The diff touches the v1 distributed helpers. First, get_local_rank moves up to sit beside get_rank instead of after get_world_size:

@@ -60,16 +60,16 @@ def get_rank() -> int:
     return int(os.getenv("RANK", "0"))
 
 
+def get_local_rank() -> int:
+    """Get local rank."""
+    return int(os.getenv("LOCAL_RANK", "0"))
+
+
 def get_world_size() -> int:
     """Get world size."""
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
-def get_local_rank() -> int:
-    """Get local rank."""
-    return int(os.getenv("LOCAL_RANK", "0"))
-
-
 def get_local_world_size() -> int:
     """Get local world size."""
     return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
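For context, these helpers read the environment variables that torchrun (and compatible launchers) export to every worker. A minimal standalone sketch; the env values are illustrative, standing in for a hypothetical 2-node x 4-GPU launch:

    import os

    # torchrun would export these; set defaults so the sketch runs by itself.
    os.environ.setdefault("RANK", "5")               # global rank across all nodes
    os.environ.setdefault("LOCAL_WORLD_SIZE", "4")   # processes on this node

    def get_rank() -> int:
        return int(os.getenv("RANK", "0"))

    def get_local_world_size() -> int:
        return int(os.getenv("LOCAL_WORLD_SIZE", "1"))

    # One thing the local/global split enables: locating this worker's node.
    node_index = get_rank() // get_local_world_size()
    print(f"global rank {get_rank()} runs on node {node_index}")  # -> node 1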
Next, the get_current_accelerator docstring is reworded:

@@ -79,7 +79,7 @@ def get_local_world_size() -> int:
 def get_current_accelerator(check_available: bool = True) -> torch.device:
     """Get current accelerator.
 
-    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
+    Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
     """
     if not hasattr(torch, "accelerator"):
         raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
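The hasattr guard above is also the pattern callers can copy when they must tolerate older torch builds. A minimal sketch (pick_device is a hypothetical helper, not from this repo; it assumes only that torch.accelerator.current_accelerator(check_available=...) exists on torch>=2.7.0, as the docstring states):

    import torch

    def pick_device() -> torch.device:
        # torch.accelerator only exists on new enough torch builds.
        if hasattr(torch, "accelerator"):
            device = torch.accelerator.current_accelerator(check_available=True)
            if device is not None:  # None means no accelerator was detected
                return device
        # Fallback probing for older torch versions.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(pick_device())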
In all_reduce, ndarray inputs are now placed on the accelerator and cast to float up front:

@@ -123,7 +123,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
     is_tensor = isinstance(data, torch.Tensor)
 
     if is_ndarray:
-        data = torch.from_numpy(data)
+        data = torch.from_numpy(data).to(device=device, dtype=torch.float)
     elif not is_tensor:
         data = torch.tensor(data, dtype=torch.float, device=device)
 
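The added .to(device=..., dtype=...) matters because torch.from_numpy always yields a CPU tensor sharing the array's memory, with an integer dtype if the array is integral; a GPU-backed process group (e.g. NCCL) needs the operand on the accelerator, and a MEAN reduction needs floats. A standalone sketch of just the tensor-side behavior (no process group is created):

    import numpy as np
    import torch

    arr = np.array([1, 2, 3], dtype=np.int64)
    t = torch.from_numpy(arr)            # CPU tensor, shares memory, still int64
    assert t.device.type == "cpu" and t.dtype == torch.int64

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    t = torch.from_numpy(arr).to(device=device, dtype=torch.float)  # the patched form
    assert t.dtype == torch.float32      # safe for a MEAN all-reduce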
And on the return path, ndarray results are moved back to the CPU before conversion:

@@ -140,7 +140,7 @@ def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["
     if is_tensor:
         return data
     elif is_ndarray:
-        return data.numpy()
+        return data.cpu().numpy()
     elif data.numel() == 1:
         return data.item()
     else:
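This is the matching fix for the hunk above: Tensor.numpy() raises a TypeError for a tensor living on an accelerator, and ndarray inputs now do live there, so the result is routed through .cpu() first. A short sketch (the CUDA branch only runs where a GPU is present):

    import torch

    t = torch.ones(3)
    if torch.cuda.is_available():
        t = t.to("cuda")
        # t.numpy() here would raise TypeError ("can't convert cuda:0 device tensor ...")
    out = t.cpu().numpy()  # works whether t is on CPU or GPU
    print(out)             # [1. 1. 1.]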