Increase performance for conversions including axis angles (#1948)

Summary:
This is an extension of https://github.com/facebookresearch/pytorch3d/issues/1544 with various speed, stability, and readability improvements. (I could not find a way to make a commit to the existing PR). This PR is still based on the [Rodrigues' rotation formula](https://en.wikipedia.org/wiki/Rotation_formalisms_in_three_dimensions#Rotation_matrix_%E2%86%94_Euler_axis/angle).

The motivation is the same: this change speeds up the conversions by up to 10x, depending on the device, batch size, etc.

### Notes
- As the angles get very close to `π`, the results of the existing implementation and the proposed one start to differ. However, my understanding is that this is not a problem, because in general the axis cannot be stably inferred from the rotation matrix in this case.
- bottler, I tried to follow conventions similar to those of the existing functions for dealing with unusual angles; let me know if anything needs to be changed before this can be merged.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1948

Reviewed By: MichaelRamamonjisoa

Differential Revision: D69193009

Pulled By: bottler

fbshipit-source-id: e5ed34b45b625114ec4419bb89e22a6aefad4eeb
This commit is contained in:
alex-bene 2025-02-07 07:37:42 -08:00 committed by Facebook GitHub Bot
parent 215590b497
commit 7a3c0cbc9d
2 changed files with 83 additions and 33 deletions

View File

@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
return out[..., 1:] return out[..., 1:]
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor:
""" """
Convert rotations given as axis/angle to rotation matrices. Convert rotations given as axis/angle to rotation matrices.
@ -472,27 +472,93 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
as a tensor of shape (..., 3), where the magnitude is as a tensor of shape (..., 3), where the magnitude is
the angle turned anticlockwise in radians around the the angle turned anticlockwise in radians around the
vector's direction. vector's direction.
fast: Whether to use the new faster implementation (based on the
Rodrigues formula) instead of the original implementation (which
first converted to a quaternion and then back to a rotation matrix).
Returns: Returns:
Rotation matrices as tensor of shape (..., 3, 3). Rotation matrices as tensor of shape (..., 3, 3).
""" """
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) if not fast:
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
shape = axis_angle.shape
device, dtype = axis_angle.device, axis_angle.dtype
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)
rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
cross_product_matrix = torch.stack(
[zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
).view(shape + (3,))
cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
identity = torch.eye(3, dtype=dtype, device=device)
angles_sqrd = angles * angles
angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
return (
identity.expand(cross_product_matrix.shape)
+ torch.sinc(angles / torch.pi) * cross_product_matrix
+ ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
)
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        fast: Whether to use the new faster implementation (based on the
            Rodrigues formula) instead of the original implementation (which
            first converted to a quaternion and then to axis/angle).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.

    Raises:
        ValueError: if `fast` is True and the last two dimensions of
            `matrix` are not (3, 3).
    """
    if not fast:
        return quaternion_to_axis_angle(matrix_to_quaternion(matrix))

    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    # Rodrigues: R = cos(a) I + sin(a) [n]x + (1 - cos(a)) n n^T.
    # The antisymmetric part of R therefore gives 2 sin(a) n ...
    omegas = torch.stack(
        [
            matrix[..., 2, 1] - matrix[..., 1, 2],
            matrix[..., 0, 2] - matrix[..., 2, 0],
            matrix[..., 1, 0] - matrix[..., 0, 1],
        ],
        dim=-1,
    )
    norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
    # ... and the trace gives 1 + 2 cos(a).
    diagonals = torch.diagonal(matrix, dim1=-2, dim2=-1)
    traces = diagonals.sum(-1, keepdim=True)
    angles = torch.atan2(norms, traces - 1)
    cosines = 0.5 * (traces - 1)

    # Only angles close to pi need special handling: there sin(a) -> 0 and
    # the antisymmetric part no longer determines the axis reliably. Small
    # angles are fine because sinc(a / pi) -> 1.
    near_pi = torch.isclose(cosines, -torch.ones_like(cosines))

    # Generic case: a * n = 0.5 * omegas * a / sin(a). The denominator is
    # replaced by 1 where this branch is not selected, so that no inf/nan
    # is produced (and back-propagated) for the masked-out entries.
    sincs = torch.sinc(angles / torch.pi)
    sincs = torch.where(near_pi, torch.ones_like(sincs), sincs)
    generic = 0.5 * omegas / sincs

    # Near pi: the symmetric part of R is cos(a) I + (1 - cos(a)) n n^T, so
    # row k of (sym(R) - cos(a) I) equals (1 - cos(a)) * n_k * n. Choosing
    # the row with the largest diagonal entry picks the largest |n_k|, so
    # the selected row is never (close to) zero.
    identity = torch.eye(3, dtype=matrix.dtype, device=matrix.device)
    outer = 0.5 * (matrix + matrix.transpose(-1, -2)) - cosines.unsqueeze(-1) * identity
    pivots = diagonals.argmax(dim=-1)
    index = pivots[..., None, None].expand(*pivots.shape, 1, 3)
    rows = torch.gather(outer, -2, index).squeeze(-2)
    # Per-rotation normalization (eps-clamped, so masked-out rows stay finite).
    axes = torch.nn.functional.normalize(rows, p=2, dim=-1)
    # The row only fixes the axis up to sign; whenever sin(a) is not exactly
    # zero the antisymmetric part (2 sin(a) n, sin(a) >= 0) resolves it.
    flipped = (axes * omegas).sum(-1, keepdim=True) < 0
    axes = torch.where(flipped, -axes, axes)

    return torch.where(near_pi, angles * axes, generic)
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
quaternions with real part first, as tensor of shape (..., 4). quaternions with real part first, as tensor of shape (..., 4).
""" """
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
half_angles = angles * 0.5 sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi)
eps = 1e-6 return torch.cat(
small_angles = angles.abs() < eps [torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
) )
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
)
return quaternions
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
""" """
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1]) half_angles = torch.atan2(norms, quaternions[..., :1])
angles = 2 * half_angles sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
eps = 1e-6 # angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
small_angles = angles.abs() < eps # can't be zero
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
return quaternions[..., 1:] / sin_half_angles_over_angles return quaternions[..., 1:] / sin_half_angles_over_angles

View File

@ -204,6 +204,9 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
n_repetitions = 20 n_repetitions = 20
data = torch.rand(n_repetitions, 3) data = torch.rand(n_repetitions, 3)
matrices = axis_angle_to_matrix(data) matrices = axis_angle_to_matrix(data)
self.assertClose(data, matrix_to_axis_angle(matrices), atol=2e-6)
self.assertClose(data, matrix_to_axis_angle(matrices, fast=True), atol=2e-6)
matrices = axis_angle_to_matrix(data, fast=True)
mdata = matrix_to_axis_angle(matrices) mdata = matrix_to_axis_angle(matrices)
self.assertClose(data, mdata, atol=2e-6) self.assertClose(data, mdata, atol=2e-6)
@ -221,8 +224,10 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
"""mtx -> axis_angle -> mtx""" """mtx -> axis_angle -> mtx"""
data = random_rotations(13, dtype=torch.float64) data = random_rotations(13, dtype=torch.float64)
euler_angles = matrix_to_axis_angle(data) euler_angles = matrix_to_axis_angle(data)
mdata = axis_angle_to_matrix(euler_angles) euler_angles_fast = matrix_to_axis_angle(data)
self.assertClose(data, mdata) self.assertClose(data, axis_angle_to_matrix(euler_angles))
self.assertClose(data, axis_angle_to_matrix(euler_angles_fast))
self.assertClose(data, axis_angle_to_matrix(euler_angles, fast=True))
def test_quaternion_application(self): def test_quaternion_application(self):
"""Applying a quaternion is the same as applying the matrix.""" """Applying a quaternion is the same as applying the matrix."""