Revert _sqrt_positive_part change

Reviewed By: bottler

Differential Revision: D77549647

fbshipit-source-id: a0ef0bc015c643ad7416c781886e2e23b5105bdd
This commit is contained in:
Srivathsan Govindarajan 2025-06-30 14:13:27 -07:00 committed by Facebook GitHub Bot
parent 177eec6378
commit 267bd8ef87

View File

@ -95,7 +95,13 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
return torch.sqrt(F.relu(x))
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: