axis_angle representation of rotations

Summary: We can represent a rotation as a vector in the axis direction, whose length is the rotation anticlockwise in radians around that axis.

Reviewed By: gkioxari

Differential Revision: D24306293

fbshipit-source-id: 2e0f138eda8329f6cceff600a6e5f17a00e4deb7
This commit is contained in:
Jeremy Reizenstein
2020-10-21 06:21:20 -07:00
committed by Facebook GitHub Bot
parent 005a334f99
commit c93c4dd7b2
2 changed files with 131 additions and 7 deletions

View File

@@ -8,12 +8,16 @@ import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms.rotation_conversions import (
axis_angle_to_matrix,
axis_angle_to_quaternion,
euler_angles_to_matrix,
matrix_to_axis_angle,
matrix_to_euler_angles,
matrix_to_quaternion,
matrix_to_rotation_6d,
quaternion_apply,
quaternion_multiply,
quaternion_to_axis_angle,
quaternion_to_matrix,
random_quaternions,
random_rotation,
@@ -60,13 +64,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
"""quat -> mtx -> quat"""
data = random_quaternions(13, dtype=torch.float64)
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)
def test_to_quat(self):
"""mtx -> quat -> mtx"""
data = random_rotations(13, dtype=torch.float64)
mdata = quaternion_to_matrix(matrix_to_quaternion(data))
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)
def test_quat_grad_exists(self):
"""Quaternion calculations are differentiable."""
@@ -107,13 +111,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
for convention in self._tait_bryan_conventions():
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)
data[:, 1] += half_pi
for convention in self._proper_euler_conventions():
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)
def test_to_euler(self):
"""mtx -> euler -> mtx"""
@@ -121,7 +125,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
for convention in self._all_euler_angle_conventions():
euler_angles = matrix_to_euler_angles(data, convention)
mdata = euler_angles_to_matrix(euler_angles, convention)
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)
def test_euler_grad_exists(self):
"""Euler angle calculations are differentiable."""
@@ -143,7 +147,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
ab_matrix = torch.matmul(a_matrix, b_matrix)
ab_from_matrix = matrix_to_quaternion(ab_matrix)
self.assertEqual(ab.shape, ab_from_matrix.shape)
self.assertTrue(torch.allclose(ab, ab_from_matrix))
self.assertClose(ab, ab_from_matrix)
def test_matrix_to_quaternion_corner_case(self):
"""Check no bad gradients from sqrt(0)."""
@@ -159,6 +163,31 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
self.assertClose(matrix, 0.95 * torch.eye(3))
def test_from_axis_angle(self):
"""axis_angle -> mtx -> axis_angle"""
n_repetitions = 20
data = torch.rand(n_repetitions, 3)
matrices = axis_angle_to_matrix(data)
mdata = matrix_to_axis_angle(matrices)
self.assertClose(data, mdata, atol=2e-6)
def test_from_axis_angle_has_grad(self):
n_repetitions = 20
data = torch.rand(n_repetitions, 3, requires_grad=True)
matrices = axis_angle_to_matrix(data)
mdata = matrix_to_axis_angle(matrices)
quats = axis_angle_to_quaternion(data)
mdata2 = quaternion_to_axis_angle(quats)
(grad,) = torch.autograd.grad(mdata.sum() + mdata2.sum(), data)
self.assertTrue(torch.isfinite(grad).all())
def test_to_axis_angle(self):
"""mtx -> axis_angle -> mtx"""
data = random_rotations(13, dtype=torch.float64)
euler_angles = matrix_to_axis_angle(data)
mdata = axis_angle_to_matrix(euler_angles)
self.assertClose(data, mdata)
def test_quaternion_application(self):
"""Applying a quaternion is the same as applying the matrix."""
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
@@ -166,7 +195,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
transform1 = quaternion_apply(quaternions, points)
transform2 = torch.matmul(matrices, points[..., None])[..., 0]
self.assertTrue(torch.allclose(transform1, transform2))
self.assertClose(transform1, transform2)
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
self.assertTrue(torch.isfinite(p).all())