Increase performance for conversions including axis angles (#1948)

Summary:
This is an extension of https://github.com/facebookresearch/pytorch3d/issues/1544 with various speed, stability, and readability improvements. (I could not find a way to make a commit to the existing PR). This PR is still based on the [Rodrigues' rotation formula](https://en.wikipedia.org/wiki/Rotation_formalisms_in_three_dimensions#Rotation_matrix_%E2%86%94_Euler_axis/angle).

The motivation is the same; this change speeds up the conversions up to 10x, depending on the device, batch size, etc.

### Notes
- As the angles get very close to `π`, the existing implementation and the proposed one start to differ. However, (my understanding is that) this is not a problem as the axis can not be stably inferred from the rotation matrix in this case in general.
- bottler , I tried to follow similar conventions as existing functions to deal with weird angles, let me know if something needs to be changed to merge this.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1948

Reviewed By: MichaelRamamonjisoa

Differential Revision: D69193009

Pulled By: bottler

fbshipit-source-id: e5ed34b45b625114ec4419bb89e22a6aefad4eeb
This commit is contained in:
alex-bene 2025-02-07 07:37:42 -08:00 committed by Facebook GitHub Bot
parent 215590b497
commit 7a3c0cbc9d
2 changed files with 83 additions and 33 deletions

View File

@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
return out[..., 1:]
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor:
"""
Convert rotations given as axis/angle to rotation matrices.
@ -472,27 +472,93 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
as a tensor of shape (..., 3), where the magnitude is
the angle turned anticlockwise in radians around the
vector's direction.
fast: Whether to use the new faster implementation (based on the
Rodrigues formula) instead of the original implementation (which
first converted to a quaternion and then back to a rotation matrix).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
if not fast:
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
shape = axis_angle.shape
device, dtype = axis_angle.device, axis_angle.dtype
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)
rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
cross_product_matrix = torch.stack(
[zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
).view(shape + (3,))
cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
identity = torch.eye(3, dtype=dtype, device=device)
angles_sqrd = angles * angles
angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
return (
identity.expand(cross_product_matrix.shape)
+ torch.sinc(angles / torch.pi) * cross_product_matrix
+ ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
)
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to axis/angle.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
fast: Whether to use the new faster implementation (based on the
Rodrigues formula) instead of the original implementation (which
first converted to a quaternion and then back to a rotation matrix).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
if not fast:
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
omegas = torch.stack(
[
matrix[..., 2, 1] - matrix[..., 1, 2],
matrix[..., 0, 2] - matrix[..., 2, 0],
matrix[..., 1, 0] - matrix[..., 0, 1],
],
dim=-1,
)
norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
angles = torch.atan2(norms, traces - 1)
zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)
near_pi = torch.isclose(((traces - 1) / 2).abs(), torch.ones_like(traces)).squeeze(
-1
)
axis_angles = torch.empty_like(omegas)
axis_angles[~near_pi] = (
0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi)
)
# this derives from: nnT = (R + 1) / 2
n = 0.5 * (
matrix[near_pi][..., 0, :]
+ torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
)
axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)
return axis_angles
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
quaternions with real part first, as tensor of shape (..., 4).
"""
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
half_angles = angles * 0.5
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi)
return torch.cat(
[torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
)
return quaternions
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
"""
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1])
angles = 2 * half_angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
# angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
# can't be zero
return quaternions[..., 1:] / sin_half_angles_over_angles

View File

@ -204,6 +204,9 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
n_repetitions = 20
data = torch.rand(n_repetitions, 3)
matrices = axis_angle_to_matrix(data)
self.assertClose(data, matrix_to_axis_angle(matrices), atol=2e-6)
self.assertClose(data, matrix_to_axis_angle(matrices, fast=True), atol=2e-6)
matrices = axis_angle_to_matrix(data, fast=True)
mdata = matrix_to_axis_angle(matrices)
self.assertClose(data, mdata, atol=2e-6)
@ -221,8 +224,10 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
"""mtx -> axis_angle -> mtx"""
data = random_rotations(13, dtype=torch.float64)
euler_angles = matrix_to_axis_angle(data)
mdata = axis_angle_to_matrix(euler_angles)
self.assertClose(data, mdata)
euler_angles_fast = matrix_to_axis_angle(data)
self.assertClose(data, axis_angle_to_matrix(euler_angles))
self.assertClose(data, axis_angle_to_matrix(euler_angles_fast))
self.assertClose(data, axis_angle_to_matrix(euler_angles, fast=True))
def test_quaternion_application(self):
"""Applying a quaternion is the same as applying the matrix."""