From 71db7a0ea293f2626f5ea1c70366870a607129d0 Mon Sep 17 00:00:00 2001 From: Srivathsan Govindarajan Date: Wed, 25 Jun 2025 01:18:46 -0700 Subject: [PATCH] Removing dynamic shape ops and boolean indexing in matrix_to_quaternion Summary: The current implementation of `matrix_to_quaternion` and `_sqrt_positive_part` uses boolean indexing, which can slow down performance and cause incompatibility with `torch.compile` unless `torch._dynamo.config.capture_dynamic_output_shape_ops` is set to `True`. To enhance performance and compatibility, I recommend using `torch.gather` to select the best-conditioned quaternions and `F.relu` instead of `x>0` (bottler's suggestion) For a detailed comparison of the implementation differences when using `torch.compile`, please refer to my Bento notebook N7438339. Reviewed By: bottler Differential Revision: D77176230 fbshipit-source-id: 9a6a2e0015b5865056297d5f45badc3c425b93ce --- pytorch3d/transforms/rotation_conversions.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index c3320a45..1aec4fc6 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -95,13 +95,7 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ - ret = torch.zeros_like(x) - positive_mask = x > 0 - if torch.is_grad_enabled(): - ret[positive_mask] = torch.sqrt(x[positive_mask]) - else: - ret = torch.where(positive_mask, torch.sqrt(x), ret) - return ret + return torch.sqrt(F.relu(x)) def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: @@ -160,9 +154,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) - out = quat_candidates[ - F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : - ].reshape(batch_dim + (4,)) + indices = q_abs.argmax(dim=-1, keepdim=True) + expand_dims = list(batch_dim) + [1, 4] + gather_indices = indices.unsqueeze(-1).expand(expand_dims) + out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2) return standardize_quaternion(out)