From 61cc79aa340412c33407771bc97236ccd9ee1548 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 6 Mar 2026 05:23:55 -0800 Subject: [PATCH] Make _sqrt_positive_part ONNX-exportable Summary: Replace boolean indexing and torch.is_grad_enabled() control flow in _sqrt_positive_part with a pure torch.where implementation. The old code used ret[positive_mask] = torch.sqrt(x[positive_mask]) which produces an incorrect ONNX Where/index_put node with mismatched broadcast shapes when the model is exported via torch.onnx.export. The new implementation substitutes 1.0 for non-positive values before sqrt (avoiding infinite gradient at sqrt(0)) and masks the result back to 0, preserving the zero-subgradient-at-zero property. Fixes https://github.com/facebookresearch/pytorch3d/issues/2020 Reviewed By: sgrigory Differential Revision: D94365479 fbshipit-source-id: a1ebe8dc077573f83efc262520b6669159b83ef0 --- pytorch3d/transforms/rotation_conversions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 870ee086..eb1a42cc 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -94,13 +94,9 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ - ret = torch.zeros_like(x) positive_mask = x > 0 - if torch.is_grad_enabled(): - ret[positive_mask] = torch.sqrt(x[positive_mask]) - else: - ret = torch.where(positive_mask, torch.sqrt(x), ret) - return ret + safe_x = torch.where(positive_mask, x, 1.0) + return torch.where(positive_mask, torch.sqrt(safe_x), 0.0) def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: