diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 0164e413..a9fcae22 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -97,7 +97,10 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ ret = torch.zeros_like(x) positive_mask = x > 0 - ret[positive_mask] = torch.sqrt(x[positive_mask]) + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) return ret