diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index c3320a45..1aec4fc6 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -95,13 +95,7 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ - ret = torch.zeros_like(x) - positive_mask = x > 0 - if torch.is_grad_enabled(): - ret[positive_mask] = torch.sqrt(x[positive_mask]) - else: - ret = torch.where(positive_mask, torch.sqrt(x), ret) - return ret + return torch.sqrt(F.relu(x)) def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: @@ -160,9 +154,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) - out = quat_candidates[ - F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : - ].reshape(batch_dim + (4,)) + indices = q_abs.argmax(dim=-1, keepdim=True) + expand_dims = list(batch_dim) + [1, 4] + gather_indices = indices.unsqueeze(-1).expand(expand_dims) + out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2) return standardize_quaternion(out)