diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index a9fcae22..0208e112 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten return out[..., 1:] -def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: +def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor: """ Convert rotations given as axis/angle to rotation matrices. @@ -472,27 +472,93 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. + fast: Whether to use the new faster implementation (based on the + Rodrigues formula) instead of the original implementation (which + first converted to a quaternion and then back to a rotation matrix). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ - return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + if not fast: + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + shape = axis_angle.shape + device, dtype = axis_angle.device, axis_angle.dtype + + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1) + + rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2] + zeros = torch.zeros(shape[:-1], dtype=dtype, device=device) + cross_product_matrix = torch.stack( + [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1 + ).view(shape + (3,)) + cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix + + identity = torch.eye(3, dtype=dtype, device=device) + angles_sqrd = angles * angles + angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd) + return ( + identity.expand(cross_product_matrix.shape) + + torch.sinc(angles / torch.pi) * cross_product_matrix + + ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd + ) -def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: +def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor: """ Convert rotations given as rotation matrices to axis/angle. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). + fast: Whether to use the new faster implementation (based on the + Rodrigues formula) instead of the original implementation (which + first converted to a quaternion and then back to a rotation matrix). Returns: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. + """ - return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + if not fast: + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + omegas = torch.stack( + [ + matrix[..., 2, 1] - matrix[..., 1, 2], + matrix[..., 0, 2] - matrix[..., 2, 0], + matrix[..., 1, 0] - matrix[..., 0, 1], + ], + dim=-1, + ) + norms = torch.norm(omegas, p=2, dim=-1, keepdim=True) + traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1) + angles = torch.atan2(norms, traces - 1) + + zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device) + omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas) + + near_pi = torch.isclose(((traces - 1) / 2).abs(), torch.ones_like(traces)).squeeze( + -1 + ) + + axis_angles = torch.empty_like(omegas) + axis_angles[~near_pi] = ( + 0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi) + ) + + # this derives from: nnT = (R + 1) / 2 + n = 0.5 * ( + matrix[near_pi][..., 0, :] + + torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device) + ) + axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n) + + return axis_angles def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: @@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: quaternions with real part first, as tensor of shape (..., 4). """ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) - half_angles = angles * 0.5 - eps = 1e-6 - small_angles = angles.abs() < eps - sin_half_angles_over_angles = torch.empty_like(angles) - sin_half_angles_over_angles[~small_angles] = ( - torch.sin(half_angles[~small_angles]) / angles[~small_angles] + sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi) + return torch.cat( + [torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1 ) - # for x small, sin(x/2) is about x/2 - (x/2)^3/6 - # so sin(x/2)/x is about 1/2 - (x*x)/48 - sin_half_angles_over_angles[small_angles] = ( - 0.5 - (angles[small_angles] * angles[small_angles]) / 48 - ) - quaternions = torch.cat( - [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 - ) - return quaternions def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: @@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: """ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) half_angles = torch.atan2(norms, quaternions[..., :1]) - angles = 2 * half_angles - eps = 1e-6 - small_angles = angles.abs() < eps - sin_half_angles_over_angles = torch.empty_like(angles) - sin_half_angles_over_angles[~small_angles] = ( - torch.sin(half_angles[~small_angles]) / angles[~small_angles] - ) - # for x small, sin(x/2) is about x/2 - (x/2)^3/6 - # so sin(x/2)/x is about 1/2 - (x*x)/48 - sin_half_angles_over_angles[small_angles] = ( - 0.5 - (angles[small_angles] * angles[small_angles]) / 48 - ) + sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi) + # angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles + # can't be zero return quaternions[..., 1:] / sin_half_angles_over_angles diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py index 7090d3ca..87bd250f 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -204,6 +204,9 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): n_repetitions = 20 data = torch.rand(n_repetitions, 3) matrices = axis_angle_to_matrix(data) + self.assertClose(data, matrix_to_axis_angle(matrices), atol=2e-6) + self.assertClose(data, matrix_to_axis_angle(matrices, fast=True), atol=2e-6) + matrices = axis_angle_to_matrix(data, fast=True) mdata = matrix_to_axis_angle(matrices) self.assertClose(data, mdata, atol=2e-6) @@ -221,8 +224,10 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): """mtx -> axis_angle -> mtx""" data = random_rotations(13, dtype=torch.float64) euler_angles = matrix_to_axis_angle(data) - mdata = axis_angle_to_matrix(euler_angles) - self.assertClose(data, mdata) + euler_angles_fast = matrix_to_axis_angle(data) + self.assertClose(data, axis_angle_to_matrix(euler_angles)) + self.assertClose(data, axis_angle_to_matrix(euler_angles_fast)) + self.assertClose(data, axis_angle_to_matrix(euler_angles, fast=True)) def test_quaternion_application(self): """Applying a quaternion is the same as applying the matrix."""