mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
axis_angle representation of rotations
Summary: We can represent a rotation as a vector in the axis direction, whose length is the rotation anticlockwise in radians around that axis. Reviewed By: gkioxari Differential Revision: D24306293 fbshipit-source-id: 2e0f138eda8329f6cceff600a6e5f17a00e4deb7
This commit is contained in:
parent
005a334f99
commit
c93c4dd7b2
@ -413,6 +413,101 @@ def quaternion_apply(quaternion, point):
|
||||
return out[..., 1:]
|
||||
|
||||
|
||||
def axis_angle_to_matrix(axis_angle):
|
||||
"""
|
||||
Convert rotations given as axis/angle to rotation matrices.
|
||||
|
||||
Args:
|
||||
axis_angle: Rotations given as a vector in axis angle form,
|
||||
as a tensor of shape (..., 3), where the magnitude is
|
||||
the angle turned anticlockwise in radians around the
|
||||
vector's direction.
|
||||
|
||||
Returns:
|
||||
Rotation matrices as tensor of shape (..., 3, 3).
|
||||
"""
|
||||
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||
|
||||
|
||||
def matrix_to_axis_angle(matrix):
|
||||
"""
|
||||
Convert rotations given as rotation matrices to axis/angle.
|
||||
|
||||
Args:
|
||||
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
||||
|
||||
Returns:
|
||||
Rotations given as a vector in axis angle form, as a tensor
|
||||
of shape (..., 3), where the magnitude is the angle
|
||||
turned anticlockwise in radians around the vector's
|
||||
direction.
|
||||
"""
|
||||
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||
|
||||
|
||||
def axis_angle_to_quaternion(axis_angle):
|
||||
"""
|
||||
Convert rotations given as axis/angle to quaternions.
|
||||
|
||||
Args:
|
||||
axis_angle: Rotations given as a vector in axis angle form,
|
||||
as a tensor of shape (..., 3), where the magnitude is
|
||||
the angle turned anticlockwise in radians around the
|
||||
vector's direction.
|
||||
|
||||
Returns:
|
||||
quaternions with real part first, as tensor of shape (..., 4).
|
||||
"""
|
||||
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
||||
half_angles = 0.5 * angles
|
||||
eps = 1e-6
|
||||
small_angles = angles.abs() < eps
|
||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||
sin_half_angles_over_angles[~small_angles] = (
|
||||
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||
)
|
||||
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||
sin_half_angles_over_angles[small_angles] = (
|
||||
0.5 - torch.square(angles[small_angles]) / 48
|
||||
)
|
||||
quaternions = torch.cat(
|
||||
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
||||
)
|
||||
return quaternions
|
||||
|
||||
|
||||
def quaternion_to_axis_angle(quaternions):
|
||||
"""
|
||||
Convert rotations given as quaternions to axis/angle.
|
||||
|
||||
Args:
|
||||
quaternions: quaternions with real part first,
|
||||
as tensor of shape (..., 4).
|
||||
|
||||
Returns:
|
||||
Rotations given as a vector in axis angle form, as a tensor
|
||||
of shape (..., 3), where the magnitude is the angle
|
||||
turned anticlockwise in radians around the vector's
|
||||
direction.
|
||||
"""
|
||||
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
||||
half_angles = torch.atan2(norms, quaternions[..., :1])
|
||||
angles = 2 * half_angles
|
||||
eps = 1e-6
|
||||
small_angles = angles.abs() < eps
|
||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||
sin_half_angles_over_angles[~small_angles] = (
|
||||
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||
)
|
||||
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||
sin_half_angles_over_angles[small_angles] = (
|
||||
0.5 - torch.square(angles[small_angles]) / 48
|
||||
)
|
||||
return quaternions[..., 1:] / sin_half_angles_over_angles
|
||||
|
||||
|
||||
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
||||
|
@ -8,12 +8,16 @@ import unittest
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.transforms.rotation_conversions import (
|
||||
axis_angle_to_matrix,
|
||||
axis_angle_to_quaternion,
|
||||
euler_angles_to_matrix,
|
||||
matrix_to_axis_angle,
|
||||
matrix_to_euler_angles,
|
||||
matrix_to_quaternion,
|
||||
matrix_to_rotation_6d,
|
||||
quaternion_apply,
|
||||
quaternion_multiply,
|
||||
quaternion_to_axis_angle,
|
||||
quaternion_to_matrix,
|
||||
random_quaternions,
|
||||
random_rotation,
|
||||
@ -60,13 +64,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
"""quat -> mtx -> quat"""
|
||||
data = random_quaternions(13, dtype=torch.float64)
|
||||
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
|
||||
self.assertTrue(torch.allclose(data, mdata))
|
||||
self.assertClose(data, mdata)
|
||||
|
||||
def test_to_quat(self):
|
||||
"""mtx -> quat -> mtx"""
|
||||
data = random_rotations(13, dtype=torch.float64)
|
||||
mdata = quaternion_to_matrix(matrix_to_quaternion(data))
|
||||
self.assertTrue(torch.allclose(data, mdata))
|
||||
self.assertClose(data, mdata)
|
||||
|
||||
def test_quat_grad_exists(self):
|
||||
"""Quaternion calculations are differentiable."""
|
||||
@ -107,13 +111,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
for convention in self._tait_bryan_conventions():
|
||||
matrices = euler_angles_to_matrix(data, convention)
|
||||
mdata = matrix_to_euler_angles(matrices, convention)
|
||||
self.assertTrue(torch.allclose(data, mdata))
|
||||
self.assertClose(data, mdata)
|
||||
|
||||
data[:, 1] += half_pi
|
||||
for convention in self._proper_euler_conventions():
|
||||
matrices = euler_angles_to_matrix(data, convention)
|
||||
mdata = matrix_to_euler_angles(matrices, convention)
|
||||
self.assertTrue(torch.allclose(data, mdata))
|
||||
self.assertClose(data, mdata)
|
||||
|
||||
def test_to_euler(self):
|
||||
"""mtx -> euler -> mtx"""
|
||||
@ -121,7 +125,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
for convention in self._all_euler_angle_conventions():
|
||||
euler_angles = matrix_to_euler_angles(data, convention)
|
||||
mdata = euler_angles_to_matrix(euler_angles, convention)
|
||||
self.assertTrue(torch.allclose(data, mdata))
|
||||
self.assertClose(data, mdata)
|
||||
|
||||
def test_euler_grad_exists(self):
|
||||
"""Euler angle calculations are differentiable."""
|
||||
@ -143,7 +147,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
ab_matrix = torch.matmul(a_matrix, b_matrix)
|
||||
ab_from_matrix = matrix_to_quaternion(ab_matrix)
|
||||
self.assertEqual(ab.shape, ab_from_matrix.shape)
|
||||
self.assertTrue(torch.allclose(ab, ab_from_matrix))
|
||||
self.assertClose(ab, ab_from_matrix)
|
||||
|
||||
def test_matrix_to_quaternion_corner_case(self):
|
||||
"""Check no bad gradients from sqrt(0)."""
|
||||
@ -159,6 +163,31 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
self.assertClose(matrix, 0.95 * torch.eye(3))
|
||||
|
||||
def test_from_axis_angle(self):
|
||||
"""axis_angle -> mtx -> axis_angle"""
|
||||
n_repetitions = 20
|
||||
data = torch.rand(n_repetitions, 3)
|
||||
matrices = axis_angle_to_matrix(data)
|
||||
mdata = matrix_to_axis_angle(matrices)
|
||||
self.assertClose(data, mdata, atol=2e-6)
|
||||
|
||||
def test_from_axis_angle_has_grad(self):
|
||||
n_repetitions = 20
|
||||
data = torch.rand(n_repetitions, 3, requires_grad=True)
|
||||
matrices = axis_angle_to_matrix(data)
|
||||
mdata = matrix_to_axis_angle(matrices)
|
||||
quats = axis_angle_to_quaternion(data)
|
||||
mdata2 = quaternion_to_axis_angle(quats)
|
||||
(grad,) = torch.autograd.grad(mdata.sum() + mdata2.sum(), data)
|
||||
self.assertTrue(torch.isfinite(grad).all())
|
||||
|
||||
def test_to_axis_angle(self):
|
||||
"""mtx -> axis_angle -> mtx"""
|
||||
data = random_rotations(13, dtype=torch.float64)
|
||||
euler_angles = matrix_to_axis_angle(data)
|
||||
mdata = axis_angle_to_matrix(euler_angles)
|
||||
self.assertClose(data, mdata)
|
||||
|
||||
def test_quaternion_application(self):
|
||||
"""Applying a quaternion is the same as applying the matrix."""
|
||||
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
|
||||
@ -166,7 +195,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
|
||||
transform1 = quaternion_apply(quaternions, points)
|
||||
transform2 = torch.matmul(matrices, points[..., None])[..., 0]
|
||||
self.assertTrue(torch.allclose(transform1, transform2))
|
||||
self.assertClose(transform1, transform2)
|
||||
|
||||
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
|
||||
self.assertTrue(torch.isfinite(p).all())
|
||||
|
Loading…
x
Reference in New Issue
Block a user