From 1b39cebe9209172e543bdb9cd4d890d4023d5caf Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Fri, 18 Jun 2021 06:39:08 -0700 Subject: [PATCH] Sign issue about quaternion_to_matrix and matrix_to_quaternion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: As reported on GitHub, `matrix_to_quaternion` was incorrect for rotations by 180°. We resolved the sign of the component `i` based on the sign of `i*r`, assuming `r > 0`, which is untrue if `r == 0`. This diff handles these special cases and ensures the sign is copied from a non-zero element. Reviewed By: bottler Differential Revision: D29149465 fbshipit-source-id: cd508cc31567fc37ea3463dd7e8c8e8d5d64a235 --- pytorch3d/transforms/rotation_conversions.py | 53 +++++++++++++++----- tests/test_rotation_conversions.py | 53 ++++++++++++++++++-- 2 files changed, 89 insertions(+), 17 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index b9050089..066cfeac 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -82,7 +82,7 @@ def _copysign(a, b): return torch.where(signs_differ, -a, a) -def _sqrt_positive_part(x): +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. @@ -93,7 +93,7 @@ def _sqrt_positive_part(x): return ret -def matrix_to_quaternion(matrix): +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: """ Convert rotations given as rotation matrices to quaternions. 
@@ -105,17 +105,44 @@ def matrix_to_quaternion(matrix): """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") - m00 = matrix[..., 0, 0] - m11 = matrix[..., 1, 1] - m22 = matrix[..., 2, 2] - o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) - x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) - y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) - z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) - o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) - o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) - o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) - return torch.stack((o0, o1, o2, o3), -1) + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(*batch_dim, 9), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # clipping is not important here; if q_abs is small, the candidate won't be picked + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clip(0.1)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] + ].reshape(*batch_dim, 4) def _axis_angle_rotation(axis: str, angle): diff --git a/tests/test_rotation_conversions.py 
b/tests/test_rotation_conversions.py index f575da6e..9875d01a 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -4,7 +4,9 @@ import itertools import math import unittest +from typing import Optional, Union +import numpy as np import torch from common_testing import TestCaseMixin from pytorch3d.transforms.rotation_conversions import ( @@ -64,7 +66,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): """quat -> mtx -> quat""" data = random_quaternions(13, dtype=torch.float64) mdata = matrix_to_quaternion(quaternion_to_matrix(data)) - self.assertClose(data, mdata) + self._assert_quaternions_close(data, mdata) def test_to_quat(self): """mtx -> quat -> mtx""" @@ -146,8 +148,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): b_matrix = quaternion_to_matrix(b) ab_matrix = torch.matmul(a_matrix, b_matrix) ab_from_matrix = matrix_to_quaternion(ab_matrix) - self.assertEqual(ab.shape, ab_from_matrix.shape) - self.assertClose(ab, ab_from_matrix) + self._assert_quaternions_close(ab, ab_from_matrix) def test_matrix_to_quaternion_corner_case(self): """Check no bad gradients from sqrt(0).""" @@ -161,7 +162,34 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): loss.backward() optimizer.step() - self.assertClose(matrix, 0.95 * torch.eye(3)) + self.assertClose(matrix, matrix, msg="Result has non-finite values") + delta = 1e-2 + self.assertLess( + matrix.trace(), + 3.0 - delta, + msg="Identity initialisation unchanged by a gradient step", + ) + + def test_matrix_to_quaternion_by_pi(self): + # We check that rotations by pi around each of the 26 + # nonzero vectors containing nothing but 0, 1 and -1 + # are mapped to the right quaternions. + # This is representative across the directions. 
+ options = [0.0, -1.0, 1.0] + axes = [ + torch.tensor(vec) + for vec in itertools.islice( # exclude [0, 0, 0] + itertools.product(options, options, options), 1, None + ) + ] + + axes = torch.nn.functional.normalize(torch.stack(axes), dim=-1) + # Rotation by pi around unit vector x is given by + # the matrix 2 x x^T - Id. + R = 2 * torch.matmul(axes[..., None], axes[..., None, :]) - torch.eye(3) + quats_hat = matrix_to_quaternion(R) + R_hat = quaternion_to_matrix(quats_hat) + self.assertClose(R, R_hat, atol=1e-3) def test_from_axis_angle(self): """axis_angle -> mtx -> axis_angle""" @@ -228,3 +256,20 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): self.assertClose( torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6 ) + + def _assert_quaternions_close( + self, + input: Union[torch.Tensor, np.ndarray], + other: Union[torch.Tensor, np.ndarray], + *, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, + msg: Optional[str] = None, + ): + self.assertEqual(np.shape(input), np.shape(other)) + dot = (input * other).sum(-1) + ones = torch.ones_like(dot) + self.assertClose( + dot.abs(), ones, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg + )