Make _sqrt_positive_part ONNX-exportable

Summary:
Replace boolean indexing and torch.is_grad_enabled() control flow in _sqrt_positive_part with a pure torch.where implementation. The old code used ret[positive_mask] = torch.sqrt(x[positive_mask]) which produces an incorrect ONNX Where/index_put node with mismatched broadcast shapes when the model is exported via torch.onnx.export.

The new implementation substitutes 1.0 for non-positive values before sqrt (avoiding infinite gradient at sqrt(0)) and masks the result back to 0, preserving the zero-subgradient-at-zero property.

Fixes https://github.com/facebookresearch/pytorch3d/issues/2020

Reviewed By: sgrigory

Differential Revision: D94365479

fbshipit-source-id: a1ebe8dc077573f83efc262520b6669159b83ef0
This commit is contained in:
Jeremy Reizenstein
2026-03-06 05:23:55 -08:00
committed by meta-codesync[bot]
parent 7a6157e38e
commit 61cc79aa34

View File

@@ -94,13 +94,9 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
safe_x = torch.where(positive_mask, x, 1.0)
return torch.where(positive_mask, torch.sqrt(safe_x), 0.0)
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: