From c93c4dd7b2b68db92997f65e11d7b08acecce891 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 21 Oct 2020 06:21:20 -0700 Subject: [PATCH] axis_angle representation of rotations Summary: We can represent a rotation as a vector in the axis direction, whose length is the rotation anticlockwise in radians around that axis. Reviewed By: gkioxari Differential Revision: D24306293 fbshipit-source-id: 2e0f138eda8329f6cceff600a6e5f17a00e4deb7 --- pytorch3d/transforms/rotation_conversions.py | 95 ++++++++++++++++++++ tests/test_rotation_conversions.py | 43 +++++++-- 2 files changed, 131 insertions(+), 7 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 9e41ed3e..254340f5 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -413,6 +413,101 @@ def quaternion_apply(quaternion, point): return out[..., 1:] +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - torch.square(angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - torch.square(angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: """ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py index 74dcd1a0..f575da6e 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -8,12 +8,16 @@ import unittest import torch from common_testing import TestCaseMixin from pytorch3d.transforms.rotation_conversions import ( + axis_angle_to_matrix, + axis_angle_to_quaternion, euler_angles_to_matrix, + matrix_to_axis_angle, matrix_to_euler_angles, matrix_to_quaternion, matrix_to_rotation_6d, quaternion_apply, quaternion_multiply, + quaternion_to_axis_angle, quaternion_to_matrix, random_quaternions, random_rotation, @@ -60,13 +64,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): """quat -> mtx -> quat""" data = random_quaternions(13, dtype=torch.float64) mdata = matrix_to_quaternion(quaternion_to_matrix(data)) - self.assertTrue(torch.allclose(data, mdata)) + self.assertClose(data, mdata) def test_to_quat(self): """mtx -> quat -> mtx""" data = random_rotations(13, dtype=torch.float64) mdata = quaternion_to_matrix(matrix_to_quaternion(data)) - self.assertTrue(torch.allclose(data, mdata)) + self.assertClose(data, mdata) def test_quat_grad_exists(self): """Quaternion calculations are differentiable.""" @@ -107,13 +111,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): for convention in self._tait_bryan_conventions(): matrices = euler_angles_to_matrix(data, convention) mdata = matrix_to_euler_angles(matrices, convention) - self.assertTrue(torch.allclose(data, mdata)) + self.assertClose(data, mdata) data[:, 1] += half_pi for convention in self._proper_euler_conventions(): matrices = euler_angles_to_matrix(data, convention) mdata = matrix_to_euler_angles(matrices, convention) - self.assertTrue(torch.allclose(data, mdata)) + self.assertClose(data, mdata) def test_to_euler(self): """mtx -> euler -> mtx""" @@ -121,7 +125,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): for convention in self._all_euler_angle_conventions(): euler_angles = matrix_to_euler_angles(data, convention) mdata = euler_angles_to_matrix(euler_angles, convention) - self.assertTrue(torch.allclose(data, mdata)) + self.assertClose(data, mdata) def test_euler_grad_exists(self): """Euler angle calculations are differentiable.""" @@ -143,7 +147,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): ab_matrix = torch.matmul(a_matrix, b_matrix) ab_from_matrix = matrix_to_quaternion(ab_matrix) self.assertEqual(ab.shape, ab_from_matrix.shape) - self.assertTrue(torch.allclose(ab, ab_from_matrix)) + self.assertClose(ab, ab_from_matrix) def test_matrix_to_quaternion_corner_case(self): """Check no bad gradients from sqrt(0).""" @@ -159,6 +163,31 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): self.assertClose(matrix, 0.95 * torch.eye(3)) + def test_from_axis_angle(self): + """axis_angle -> mtx -> axis_angle""" + n_repetitions = 20 + data = torch.rand(n_repetitions, 3) + matrices = axis_angle_to_matrix(data) + mdata = matrix_to_axis_angle(matrices) + self.assertClose(data, mdata, atol=2e-6) + + def test_from_axis_angle_has_grad(self): + n_repetitions = 20 + data = torch.rand(n_repetitions, 3, requires_grad=True) + matrices = axis_angle_to_matrix(data) + mdata = matrix_to_axis_angle(matrices) + quats = axis_angle_to_quaternion(data) + mdata2 = quaternion_to_axis_angle(quats) + (grad,) = torch.autograd.grad(mdata.sum() + mdata2.sum(), data) + self.assertTrue(torch.isfinite(grad).all()) + + def test_to_axis_angle(self): + """mtx -> axis_angle -> mtx""" + data = random_rotations(13, dtype=torch.float64) + euler_angles = matrix_to_axis_angle(data) + mdata = axis_angle_to_matrix(euler_angles) + self.assertClose(data, mdata) + def test_quaternion_application(self): """Applying a quaternion is the same as applying the matrix.""" quaternions = random_quaternions(3, torch.float64, requires_grad=True) @@ -166,7 +195,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True) transform1 = quaternion_apply(quaternions, points) transform2 = torch.matmul(matrices, points[..., None])[..., 0] - self.assertTrue(torch.allclose(transform1, transform2)) + self.assertClose(transform1, transform2) [p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions]) self.assertTrue(torch.isfinite(p).all())