diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index b5f73bf5..459441ca 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -155,10 +155,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) - - return quat_candidates[ + out = quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : ].reshape(batch_dim + (4,)) + return standardize_quaternion(out) def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: