mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
axis_angle representation of rotations
Summary: We can represent a rotation as a vector in the axis direction, whose length is the rotation anticlockwise in radians around that axis. Reviewed By: gkioxari Differential Revision: D24306293 fbshipit-source-id: 2e0f138eda8329f6cceff600a6e5f17a00e4deb7
This commit is contained in:
parent
005a334f99
commit
c93c4dd7b2
@ -413,6 +413,101 @@ def quaternion_apply(quaternion, point):
|
|||||||
return out[..., 1:]
|
return out[..., 1:]
|
||||||
|
|
||||||
|
|
||||||
|
def axis_angle_to_matrix(axis_angle):
|
||||||
|
"""
|
||||||
|
Convert rotations given as axis/angle to rotation matrices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
axis_angle: Rotations given as a vector in axis angle form,
|
||||||
|
as a tensor of shape (..., 3), where the magnitude is
|
||||||
|
the angle turned anticlockwise in radians around the
|
||||||
|
vector's direction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rotation matrices as tensor of shape (..., 3, 3).
|
||||||
|
"""
|
||||||
|
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||||
|
|
||||||
|
|
||||||
|
def matrix_to_axis_angle(matrix):
|
||||||
|
"""
|
||||||
|
Convert rotations given as rotation matrices to axis/angle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rotations given as a vector in axis angle form, as a tensor
|
||||||
|
of shape (..., 3), where the magnitude is the angle
|
||||||
|
turned anticlockwise in radians around the vector's
|
||||||
|
direction.
|
||||||
|
"""
|
||||||
|
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||||
|
|
||||||
|
|
||||||
|
def axis_angle_to_quaternion(axis_angle):
|
||||||
|
"""
|
||||||
|
Convert rotations given as axis/angle to quaternions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
axis_angle: Rotations given as a vector in axis angle form,
|
||||||
|
as a tensor of shape (..., 3), where the magnitude is
|
||||||
|
the angle turned anticlockwise in radians around the
|
||||||
|
vector's direction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
quaternions with real part first, as tensor of shape (..., 4).
|
||||||
|
"""
|
||||||
|
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
||||||
|
half_angles = 0.5 * angles
|
||||||
|
eps = 1e-6
|
||||||
|
small_angles = angles.abs() < eps
|
||||||
|
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||||
|
sin_half_angles_over_angles[~small_angles] = (
|
||||||
|
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||||
|
)
|
||||||
|
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||||
|
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||||
|
sin_half_angles_over_angles[small_angles] = (
|
||||||
|
0.5 - torch.square(angles[small_angles]) / 48
|
||||||
|
)
|
||||||
|
quaternions = torch.cat(
|
||||||
|
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
||||||
|
)
|
||||||
|
return quaternions
|
||||||
|
|
||||||
|
|
||||||
|
def quaternion_to_axis_angle(quaternions):
|
||||||
|
"""
|
||||||
|
Convert rotations given as quaternions to axis/angle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quaternions: quaternions with real part first,
|
||||||
|
as tensor of shape (..., 4).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rotations given as a vector in axis angle form, as a tensor
|
||||||
|
of shape (..., 3), where the magnitude is the angle
|
||||||
|
turned anticlockwise in radians around the vector's
|
||||||
|
direction.
|
||||||
|
"""
|
||||||
|
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
||||||
|
half_angles = torch.atan2(norms, quaternions[..., :1])
|
||||||
|
angles = 2 * half_angles
|
||||||
|
eps = 1e-6
|
||||||
|
small_angles = angles.abs() < eps
|
||||||
|
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||||
|
sin_half_angles_over_angles[~small_angles] = (
|
||||||
|
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||||
|
)
|
||||||
|
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||||
|
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||||
|
sin_half_angles_over_angles[small_angles] = (
|
||||||
|
0.5 - torch.square(angles[small_angles]) / 48
|
||||||
|
)
|
||||||
|
return quaternions[..., 1:] / sin_half_angles_over_angles
|
||||||
|
|
||||||
|
|
||||||
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
||||||
|
@ -8,12 +8,16 @@ import unittest
|
|||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.transforms.rotation_conversions import (
|
from pytorch3d.transforms.rotation_conversions import (
|
||||||
|
axis_angle_to_matrix,
|
||||||
|
axis_angle_to_quaternion,
|
||||||
euler_angles_to_matrix,
|
euler_angles_to_matrix,
|
||||||
|
matrix_to_axis_angle,
|
||||||
matrix_to_euler_angles,
|
matrix_to_euler_angles,
|
||||||
matrix_to_quaternion,
|
matrix_to_quaternion,
|
||||||
matrix_to_rotation_6d,
|
matrix_to_rotation_6d,
|
||||||
quaternion_apply,
|
quaternion_apply,
|
||||||
quaternion_multiply,
|
quaternion_multiply,
|
||||||
|
quaternion_to_axis_angle,
|
||||||
quaternion_to_matrix,
|
quaternion_to_matrix,
|
||||||
random_quaternions,
|
random_quaternions,
|
||||||
random_rotation,
|
random_rotation,
|
||||||
@ -60,13 +64,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
"""quat -> mtx -> quat"""
|
"""quat -> mtx -> quat"""
|
||||||
data = random_quaternions(13, dtype=torch.float64)
|
data = random_quaternions(13, dtype=torch.float64)
|
||||||
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
|
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
|
||||||
self.assertTrue(torch.allclose(data, mdata))
|
self.assertClose(data, mdata)
|
||||||
|
|
||||||
def test_to_quat(self):
|
def test_to_quat(self):
|
||||||
"""mtx -> quat -> mtx"""
|
"""mtx -> quat -> mtx"""
|
||||||
data = random_rotations(13, dtype=torch.float64)
|
data = random_rotations(13, dtype=torch.float64)
|
||||||
mdata = quaternion_to_matrix(matrix_to_quaternion(data))
|
mdata = quaternion_to_matrix(matrix_to_quaternion(data))
|
||||||
self.assertTrue(torch.allclose(data, mdata))
|
self.assertClose(data, mdata)
|
||||||
|
|
||||||
def test_quat_grad_exists(self):
|
def test_quat_grad_exists(self):
|
||||||
"""Quaternion calculations are differentiable."""
|
"""Quaternion calculations are differentiable."""
|
||||||
@ -107,13 +111,13 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
for convention in self._tait_bryan_conventions():
|
for convention in self._tait_bryan_conventions():
|
||||||
matrices = euler_angles_to_matrix(data, convention)
|
matrices = euler_angles_to_matrix(data, convention)
|
||||||
mdata = matrix_to_euler_angles(matrices, convention)
|
mdata = matrix_to_euler_angles(matrices, convention)
|
||||||
self.assertTrue(torch.allclose(data, mdata))
|
self.assertClose(data, mdata)
|
||||||
|
|
||||||
data[:, 1] += half_pi
|
data[:, 1] += half_pi
|
||||||
for convention in self._proper_euler_conventions():
|
for convention in self._proper_euler_conventions():
|
||||||
matrices = euler_angles_to_matrix(data, convention)
|
matrices = euler_angles_to_matrix(data, convention)
|
||||||
mdata = matrix_to_euler_angles(matrices, convention)
|
mdata = matrix_to_euler_angles(matrices, convention)
|
||||||
self.assertTrue(torch.allclose(data, mdata))
|
self.assertClose(data, mdata)
|
||||||
|
|
||||||
def test_to_euler(self):
|
def test_to_euler(self):
|
||||||
"""mtx -> euler -> mtx"""
|
"""mtx -> euler -> mtx"""
|
||||||
@ -121,7 +125,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
for convention in self._all_euler_angle_conventions():
|
for convention in self._all_euler_angle_conventions():
|
||||||
euler_angles = matrix_to_euler_angles(data, convention)
|
euler_angles = matrix_to_euler_angles(data, convention)
|
||||||
mdata = euler_angles_to_matrix(euler_angles, convention)
|
mdata = euler_angles_to_matrix(euler_angles, convention)
|
||||||
self.assertTrue(torch.allclose(data, mdata))
|
self.assertClose(data, mdata)
|
||||||
|
|
||||||
def test_euler_grad_exists(self):
|
def test_euler_grad_exists(self):
|
||||||
"""Euler angle calculations are differentiable."""
|
"""Euler angle calculations are differentiable."""
|
||||||
@ -143,7 +147,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
ab_matrix = torch.matmul(a_matrix, b_matrix)
|
ab_matrix = torch.matmul(a_matrix, b_matrix)
|
||||||
ab_from_matrix = matrix_to_quaternion(ab_matrix)
|
ab_from_matrix = matrix_to_quaternion(ab_matrix)
|
||||||
self.assertEqual(ab.shape, ab_from_matrix.shape)
|
self.assertEqual(ab.shape, ab_from_matrix.shape)
|
||||||
self.assertTrue(torch.allclose(ab, ab_from_matrix))
|
self.assertClose(ab, ab_from_matrix)
|
||||||
|
|
||||||
def test_matrix_to_quaternion_corner_case(self):
|
def test_matrix_to_quaternion_corner_case(self):
|
||||||
"""Check no bad gradients from sqrt(0)."""
|
"""Check no bad gradients from sqrt(0)."""
|
||||||
@ -159,6 +163,31 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertClose(matrix, 0.95 * torch.eye(3))
|
self.assertClose(matrix, 0.95 * torch.eye(3))
|
||||||
|
|
||||||
|
def test_from_axis_angle(self):
|
||||||
|
"""axis_angle -> mtx -> axis_angle"""
|
||||||
|
n_repetitions = 20
|
||||||
|
data = torch.rand(n_repetitions, 3)
|
||||||
|
matrices = axis_angle_to_matrix(data)
|
||||||
|
mdata = matrix_to_axis_angle(matrices)
|
||||||
|
self.assertClose(data, mdata, atol=2e-6)
|
||||||
|
|
||||||
|
def test_from_axis_angle_has_grad(self):
|
||||||
|
n_repetitions = 20
|
||||||
|
data = torch.rand(n_repetitions, 3, requires_grad=True)
|
||||||
|
matrices = axis_angle_to_matrix(data)
|
||||||
|
mdata = matrix_to_axis_angle(matrices)
|
||||||
|
quats = axis_angle_to_quaternion(data)
|
||||||
|
mdata2 = quaternion_to_axis_angle(quats)
|
||||||
|
(grad,) = torch.autograd.grad(mdata.sum() + mdata2.sum(), data)
|
||||||
|
self.assertTrue(torch.isfinite(grad).all())
|
||||||
|
|
||||||
|
def test_to_axis_angle(self):
|
||||||
|
"""mtx -> axis_angle -> mtx"""
|
||||||
|
data = random_rotations(13, dtype=torch.float64)
|
||||||
|
euler_angles = matrix_to_axis_angle(data)
|
||||||
|
mdata = axis_angle_to_matrix(euler_angles)
|
||||||
|
self.assertClose(data, mdata)
|
||||||
|
|
||||||
def test_quaternion_application(self):
|
def test_quaternion_application(self):
|
||||||
"""Applying a quaternion is the same as applying the matrix."""
|
"""Applying a quaternion is the same as applying the matrix."""
|
||||||
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
|
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
|
||||||
@ -166,7 +195,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
|
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
|
||||||
transform1 = quaternion_apply(quaternions, points)
|
transform1 = quaternion_apply(quaternions, points)
|
||||||
transform2 = torch.matmul(matrices, points[..., None])[..., 0]
|
transform2 = torch.matmul(matrices, points[..., None])[..., 0]
|
||||||
self.assertTrue(torch.allclose(transform1, transform2))
|
self.assertClose(transform1, transform2)
|
||||||
|
|
||||||
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
|
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
|
||||||
self.assertTrue(torch.isfinite(p).all())
|
self.assertTrue(torch.isfinite(p).all())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user