diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 1aec4fc6..19fb890d 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -95,7 +95,13 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ - return torch.sqrt(F.relu(x)) + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: