mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Increase performance for conversions including axis angles (#1948)
Summary: This is an extension of https://github.com/facebookresearch/pytorch3d/issues/1544 with various speed, stability, and readability improvements. (I could not find a way to make a commit to the existing PR). This PR is still based on the [Rodrigues' rotation formula](https://en.wikipedia.org/wiki/Rotation_formalisms_in_three_dimensions#Rotation_matrix_%E2%86%94_Euler_axis/angle). The motivation is the same; this change speeds up the conversions up to 10x, depending on the device, batch size, etc. ### Notes - As the angles get very close to `π`, the existing implementation and the proposed one start to differ. However, (my understanding is that) this is not a problem as the axis can not be stably inferred from the rotation matrix in this case in general. - bottler , I tried to follow similar conventions as existing functions to deal with weird angles, let me know if something needs to be changed to merge this. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1948 Reviewed By: MichaelRamamonjisoa Differential Revision: D69193009 Pulled By: bottler fbshipit-source-id: e5ed34b45b625114ec4419bb89e22a6aefad4eeb
This commit is contained in:
		
							parent
							
								
									215590b497
								
							
						
					
					
						commit
						7a3c0cbc9d
					
				@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
 | 
			
		||||
    return out[..., 1:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Convert rotations given as axis/angle to rotation matrices.
 | 
			
		||||
 | 
			
		||||
@ -472,28 +472,94 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
            as a tensor of shape (..., 3), where the magnitude is
 | 
			
		||||
            the angle turned anticlockwise in radians around the
 | 
			
		||||
            vector's direction.
 | 
			
		||||
        fast: Whether to use the new faster implementation (based on the
 | 
			
		||||
            Rodrigues formula) instead of the original implementation (which
 | 
			
		||||
            first converted to a quaternion and then back to a rotation matrix).
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Rotation matrices as tensor of shape (..., 3, 3).
 | 
			
		||||
    """
 | 
			
		||||
    if not fast:
 | 
			
		||||
        return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
 | 
			
		||||
 | 
			
		||||
    shape = axis_angle.shape
 | 
			
		||||
    device, dtype = axis_angle.device, axis_angle.dtype
 | 
			
		||||
 | 
			
		||||
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)
 | 
			
		||||
 | 
			
		||||
    rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
 | 
			
		||||
    zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
 | 
			
		||||
    cross_product_matrix = torch.stack(
 | 
			
		||||
        [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
 | 
			
		||||
    ).view(shape + (3,))
 | 
			
		||||
    cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
 | 
			
		||||
 | 
			
		||||
    identity = torch.eye(3, dtype=dtype, device=device)
 | 
			
		||||
    angles_sqrd = angles * angles
 | 
			
		||||
    angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
 | 
			
		||||
    return (
 | 
			
		||||
        identity.expand(cross_product_matrix.shape)
 | 
			
		||||
        + torch.sinc(angles / torch.pi) * cross_product_matrix
 | 
			
		||||
        + ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Convert rotations given as rotation matrices to axis/angle.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
 | 
			
		||||
        fast: Whether to use the new faster implementation (based on the
 | 
			
		||||
            Rodrigues formula) instead of the original implementation (which
 | 
			
		||||
            first converted to a quaternion and then back to a rotation matrix).
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Rotations given as a vector in axis angle form, as a tensor
 | 
			
		||||
            of shape (..., 3), where the magnitude is the angle
 | 
			
		||||
            turned anticlockwise in radians around the vector's
 | 
			
		||||
            direction.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    if not fast:
 | 
			
		||||
        return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
 | 
			
		||||
 | 
			
		||||
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
 | 
			
		||||
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
 | 
			
		||||
 | 
			
		||||
    omegas = torch.stack(
 | 
			
		||||
        [
 | 
			
		||||
            matrix[..., 2, 1] - matrix[..., 1, 2],
 | 
			
		||||
            matrix[..., 0, 2] - matrix[..., 2, 0],
 | 
			
		||||
            matrix[..., 1, 0] - matrix[..., 0, 1],
 | 
			
		||||
        ],
 | 
			
		||||
        dim=-1,
 | 
			
		||||
    )
 | 
			
		||||
    norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
 | 
			
		||||
    traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
 | 
			
		||||
    angles = torch.atan2(norms, traces - 1)
 | 
			
		||||
 | 
			
		||||
    zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
 | 
			
		||||
    omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)
 | 
			
		||||
 | 
			
		||||
    near_pi = torch.isclose(((traces - 1) / 2).abs(), torch.ones_like(traces)).squeeze(
 | 
			
		||||
        -1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    axis_angles = torch.empty_like(omegas)
 | 
			
		||||
    axis_angles[~near_pi] = (
 | 
			
		||||
        0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # this derives from: nnT = (R + 1) / 2
 | 
			
		||||
    n = 0.5 * (
 | 
			
		||||
        matrix[near_pi][..., 0, :]
 | 
			
		||||
        + torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
 | 
			
		||||
    )
 | 
			
		||||
    axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)
 | 
			
		||||
 | 
			
		||||
    return axis_angles
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        quaternions with real part first, as tensor of shape (..., 4).
 | 
			
		||||
    """
 | 
			
		||||
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
 | 
			
		||||
    half_angles = angles * 0.5
 | 
			
		||||
    eps = 1e-6
 | 
			
		||||
    small_angles = angles.abs() < eps
 | 
			
		||||
    sin_half_angles_over_angles = torch.empty_like(angles)
 | 
			
		||||
    sin_half_angles_over_angles[~small_angles] = (
 | 
			
		||||
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
 | 
			
		||||
    sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi)
 | 
			
		||||
    return torch.cat(
 | 
			
		||||
        [torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
 | 
			
		||||
    )
 | 
			
		||||
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
 | 
			
		||||
    # so sin(x/2)/x is about 1/2 - (x*x)/48
 | 
			
		||||
    sin_half_angles_over_angles[small_angles] = (
 | 
			
		||||
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
 | 
			
		||||
    )
 | 
			
		||||
    quaternions = torch.cat(
 | 
			
		||||
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
 | 
			
		||||
    )
 | 
			
		||||
    return quaternions
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
 | 
			
		||||
    half_angles = torch.atan2(norms, quaternions[..., :1])
 | 
			
		||||
    angles = 2 * half_angles
 | 
			
		||||
    eps = 1e-6
 | 
			
		||||
    small_angles = angles.abs() < eps
 | 
			
		||||
    sin_half_angles_over_angles = torch.empty_like(angles)
 | 
			
		||||
    sin_half_angles_over_angles[~small_angles] = (
 | 
			
		||||
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
 | 
			
		||||
    )
 | 
			
		||||
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
 | 
			
		||||
    # so sin(x/2)/x is about 1/2 - (x*x)/48
 | 
			
		||||
    sin_half_angles_over_angles[small_angles] = (
 | 
			
		||||
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
 | 
			
		||||
    )
 | 
			
		||||
    sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
 | 
			
		||||
    # angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
 | 
			
		||||
    # can't be zero
 | 
			
		||||
    return quaternions[..., 1:] / sin_half_angles_over_angles
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -204,6 +204,9 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        n_repetitions = 20
 | 
			
		||||
        data = torch.rand(n_repetitions, 3)
 | 
			
		||||
        matrices = axis_angle_to_matrix(data)
 | 
			
		||||
        self.assertClose(data, matrix_to_axis_angle(matrices), atol=2e-6)
 | 
			
		||||
        self.assertClose(data, matrix_to_axis_angle(matrices, fast=True), atol=2e-6)
 | 
			
		||||
        matrices = axis_angle_to_matrix(data, fast=True)
 | 
			
		||||
        mdata = matrix_to_axis_angle(matrices)
 | 
			
		||||
        self.assertClose(data, mdata, atol=2e-6)
 | 
			
		||||
 | 
			
		||||
@ -221,8 +224,10 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        """mtx -> axis_angle -> mtx"""
 | 
			
		||||
        data = random_rotations(13, dtype=torch.float64)
 | 
			
		||||
        euler_angles = matrix_to_axis_angle(data)
 | 
			
		||||
        mdata = axis_angle_to_matrix(euler_angles)
 | 
			
		||||
        self.assertClose(data, mdata)
 | 
			
		||||
        euler_angles_fast = matrix_to_axis_angle(data)
 | 
			
		||||
        self.assertClose(data, axis_angle_to_matrix(euler_angles))
 | 
			
		||||
        self.assertClose(data, axis_angle_to_matrix(euler_angles_fast))
 | 
			
		||||
        self.assertClose(data, axis_angle_to_matrix(euler_angles, fast=True))
 | 
			
		||||
 | 
			
		||||
    def test_quaternion_application(self):
 | 
			
		||||
        """Applying a quaternion is the same as applying the matrix."""
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user