mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Increase performance for conversions including axis angles (#1948)
Summary: This is an extension of https://github.com/facebookresearch/pytorch3d/issues/1544 with various speed, stability, and readability improvements. (I could not find a way to make a commit to the existing PR). This PR is still based on the [Rodrigues' rotation formula](https://en.wikipedia.org/wiki/Rotation_formalisms_in_three_dimensions#Rotation_matrix_%E2%86%94_Euler_axis/angle). The motivation is the same; this change speeds up the conversions up to 10x, depending on the device, batch size, etc. ### Notes - As the angles get very close to `π`, the existing implementation and the proposed one start to differ. However, (my understanding is that) this is not a problem as the axis can not be stably inferred from the rotation matrix in this case in general. - bottler , I tried to follow similar conventions as existing functions to deal with weird angles, let me know if something needs to be changed to merge this. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1948 Reviewed By: MichaelRamamonjisoa Differential Revision: D69193009 Pulled By: bottler fbshipit-source-id: e5ed34b45b625114ec4419bb89e22a6aefad4eeb
This commit is contained in:
parent
215590b497
commit
7a3c0cbc9d
@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
|
||||
return out[..., 1:]
|
||||
|
||||
|
||||
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||
def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotations given as axis/angle to rotation matrices.
|
||||
|
||||
@ -472,27 +472,93 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||
as a tensor of shape (..., 3), where the magnitude is
|
||||
the angle turned anticlockwise in radians around the
|
||||
vector's direction.
|
||||
fast: Whether to use the new faster implementation (based on the
|
||||
Rodrigues formula) instead of the original implementation (which
|
||||
first converted to a quaternion and then back to a rotation matrix).
|
||||
|
||||
Returns:
|
||||
Rotation matrices as tensor of shape (..., 3, 3).
|
||||
"""
|
||||
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||
if not fast:
|
||||
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||
|
||||
shape = axis_angle.shape
|
||||
device, dtype = axis_angle.device, axis_angle.dtype
|
||||
|
||||
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)
|
||||
|
||||
rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
|
||||
zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
|
||||
cross_product_matrix = torch.stack(
|
||||
[zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
|
||||
).view(shape + (3,))
|
||||
cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
|
||||
|
||||
identity = torch.eye(3, dtype=dtype, device=device)
|
||||
angles_sqrd = angles * angles
|
||||
angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
|
||||
return (
|
||||
identity.expand(cross_product_matrix.shape)
|
||||
+ torch.sinc(angles / torch.pi) * cross_product_matrix
|
||||
+ ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
|
||||
)
|
||||
|
||||
|
||||
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
|
||||
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotations given as rotation matrices to axis/angle.
|
||||
|
||||
Args:
|
||||
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
||||
fast: Whether to use the new faster implementation (based on the
|
||||
Rodrigues formula) instead of the original implementation (which
|
||||
first converted to a quaternion and then back to a rotation matrix).
|
||||
|
||||
Returns:
|
||||
Rotations given as a vector in axis angle form, as a tensor
|
||||
of shape (..., 3), where the magnitude is the angle
|
||||
turned anticlockwise in radians around the vector's
|
||||
direction.
|
||||
|
||||
"""
|
||||
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||
if not fast:
|
||||
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
omegas = torch.stack(
|
||||
[
|
||||
matrix[..., 2, 1] - matrix[..., 1, 2],
|
||||
matrix[..., 0, 2] - matrix[..., 2, 0],
|
||||
matrix[..., 1, 0] - matrix[..., 0, 1],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
|
||||
traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
|
||||
angles = torch.atan2(norms, traces - 1)
|
||||
|
||||
zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
|
||||
omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)
|
||||
|
||||
near_pi = torch.isclose(((traces - 1) / 2).abs(), torch.ones_like(traces)).squeeze(
|
||||
-1
|
||||
)
|
||||
|
||||
axis_angles = torch.empty_like(omegas)
|
||||
axis_angles[~near_pi] = (
|
||||
0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi)
|
||||
)
|
||||
|
||||
# this derives from: nnT = (R + 1) / 2
|
||||
n = 0.5 * (
|
||||
matrix[near_pi][..., 0, :]
|
||||
+ torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
|
||||
)
|
||||
axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)
|
||||
|
||||
return axis_angles
|
||||
|
||||
|
||||
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||
@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||
quaternions with real part first, as tensor of shape (..., 4).
|
||||
"""
|
||||
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
||||
half_angles = angles * 0.5
|
||||
eps = 1e-6
|
||||
small_angles = angles.abs() < eps
|
||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||
sin_half_angles_over_angles[~small_angles] = (
|
||||
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||
sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi)
|
||||
return torch.cat(
|
||||
[torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
|
||||
)
|
||||
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||
sin_half_angles_over_angles[small_angles] = (
|
||||
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
||||
)
|
||||
quaternions = torch.cat(
|
||||
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
||||
)
|
||||
return quaternions
|
||||
|
||||
|
||||
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
||||
half_angles = torch.atan2(norms, quaternions[..., :1])
|
||||
angles = 2 * half_angles
|
||||
eps = 1e-6
|
||||
small_angles = angles.abs() < eps
|
||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||
sin_half_angles_over_angles[~small_angles] = (
|
||||
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||
)
|
||||
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||
sin_half_angles_over_angles[small_angles] = (
|
||||
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
||||
)
|
||||
sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
|
||||
# angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
|
||||
# can't be zero
|
||||
return quaternions[..., 1:] / sin_half_angles_over_angles
|
||||
|
||||
|
||||
|
@ -204,6 +204,9 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
n_repetitions = 20
|
||||
data = torch.rand(n_repetitions, 3)
|
||||
matrices = axis_angle_to_matrix(data)
|
||||
self.assertClose(data, matrix_to_axis_angle(matrices), atol=2e-6)
|
||||
self.assertClose(data, matrix_to_axis_angle(matrices, fast=True), atol=2e-6)
|
||||
matrices = axis_angle_to_matrix(data, fast=True)
|
||||
mdata = matrix_to_axis_angle(matrices)
|
||||
self.assertClose(data, mdata, atol=2e-6)
|
||||
|
||||
@ -221,8 +224,10 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
"""mtx -> axis_angle -> mtx"""
|
||||
data = random_rotations(13, dtype=torch.float64)
|
||||
euler_angles = matrix_to_axis_angle(data)
|
||||
mdata = axis_angle_to_matrix(euler_angles)
|
||||
self.assertClose(data, mdata)
|
||||
euler_angles_fast = matrix_to_axis_angle(data)
|
||||
self.assertClose(data, axis_angle_to_matrix(euler_angles))
|
||||
self.assertClose(data, axis_angle_to_matrix(euler_angles_fast))
|
||||
self.assertClose(data, axis_angle_to_matrix(euler_angles, fast=True))
|
||||
|
||||
def test_quaternion_application(self):
|
||||
"""Applying a quaternion is the same as applying the matrix."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user