diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 870ee086..eb1a42cc 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -94,13 +94,9 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ - ret = torch.zeros_like(x) positive_mask = x > 0 - if torch.is_grad_enabled(): - ret[positive_mask] = torch.sqrt(x[positive_mask]) - else: - ret = torch.where(positive_mask, torch.sqrt(x), ret) - return ret + safe_x = torch.where(positive_mask, x, 1.0) + return torch.where(positive_mask, torch.sqrt(safe_x), 0.0) def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: