From 267bd8ef873d2de8502f5ba52af718d2bdca3375 Mon Sep 17 00:00:00 2001 From: Srivathsan Govindarajan Date: Mon, 30 Jun 2025 14:13:27 -0700 Subject: [PATCH] Revert `_sqrt_positive_part` change Reviewed By: bottler Differential Revision: D77549647 fbshipit-source-id: a0ef0bc015c643ad7416c781886e2e23b5105bdd --- pytorch3d/transforms/rotation_conversions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 1aec4fc6..19fb890d 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -95,7 +95,13 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ - return torch.sqrt(F.relu(x)) + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: