Make some matrix conversion jittable (#898)

Summary:
Make sure the functions from `rotation_conversion` are jittable, and add some type hints.

Add tests to verify this is the case.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/898

Reviewed By: patricklabatut

Differential Revision: D31926103

Pulled By: bottler

fbshipit-source-id: bff6013c5ca2d452e37e631bd902f0674d5ca091
This commit is contained in:
una-dinosauria 2021-10-26 14:30:45 -07:00 committed by Facebook GitHub Bot
parent 29417d1f9b
commit bee31c48d3
2 changed files with 61 additions and 30 deletions

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import functools
from typing import Optional from typing import Optional
import torch import torch
@ -39,7 +38,7 @@ e.g.
""" """
def quaternion_to_matrix(quaternions): def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
""" """
Convert rotations given as quaternions to rotation matrices. Convert rotations given as quaternions to rotation matrices.
@ -70,7 +69,7 @@ def quaternion_to_matrix(quaternions):
return o.reshape(quaternions.shape[:-1] + (3, 3)) return o.reshape(quaternions.shape[:-1] + (3, 3))
def _copysign(a, b): def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
""" """
Return a tensor where each element has the absolute value taken from the, Return a tensor where each element has the absolute value taken from the,
corresponding element of a, with sign taken from the corresponding corresponding element of a, with sign taken from the corresponding
@ -114,7 +113,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
batch_dim = matrix.shape[:-2] batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(*batch_dim, 9), dim=-1 matrix.reshape(batch_dim + (9,)), dim=-1
) )
q_abs = _sqrt_positive_part( q_abs = _sqrt_positive_part(
@ -142,17 +141,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
# We floor here at 0.1 but the exact level is not important; if q_abs is small, # We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked. # the candidate won't be picked.
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1))) flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign), # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator) # forall i; we pick the best-conditioned one (with the largest denominator)
return quat_candidates[ return quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
].reshape(*batch_dim, 4) ].reshape(batch_dim + (4,))
def _axis_angle_rotation(axis: str, angle): def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
""" """
Return the rotation matrices for one of the rotations about an axis Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given. of which Euler angles describe, for each value of the angle given.
@ -172,15 +172,17 @@ def _axis_angle_rotation(axis: str, angle):
if axis == "X": if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y": elif axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z": elif axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
else:
raise ValueError("letter must be either X, Y or Z.")
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str): def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
""" """
Convert rotations given as Euler angles in radians to rotation matrices. Convert rotations given as Euler angles in radians to rotation matrices.
@ -201,13 +203,17 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention: for letter in convention:
if letter not in ("X", "Y", "Z"): if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.") raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) matrices = [
return functools.reduce(torch.matmul, matrices) _axis_angle_rotation(c, e)
for c, e in zip(convention, torch.unbind(euler_angles, -1))
]
# return functools.reduce(torch.matmul, matrices)
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
def _angle_from_tan( def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
): ) -> torch.Tensor:
""" """
Extract the first or third Euler angle from the two members of Extract the first or third Euler angle from the two members of
the matrix which are positive constant times its sine and cosine. the matrix which are positive constant times its sine and cosine.
@ -238,16 +244,17 @@ def _angle_from_tan(
return torch.atan2(data[..., i2], -data[..., i1]) return torch.atan2(data[..., i2], -data[..., i1])
def _index_from_letter(letter: str): def _index_from_letter(letter: str) -> int:
if letter == "X": if letter == "X":
return 0 return 0
if letter == "Y": if letter == "Y":
return 1 return 1
if letter == "Z": if letter == "Z":
return 2 return 2
raise ValueError("letter must be either X, Y or Z.")
def matrix_to_euler_angles(matrix, convention: str): def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
""" """
Convert rotations given as rotation matrices to Euler angles in radians. Convert rotations given as rotation matrices to Euler angles in radians.
@ -291,7 +298,7 @@ def matrix_to_euler_angles(matrix, convention: str):
def random_quaternions( def random_quaternions(
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
): ) -> torch.Tensor:
""" """
Generate random quaternions representing rotations, Generate random quaternions representing rotations,
i.e. versors with nonnegative real part. i.e. versors with nonnegative real part.
@ -305,6 +312,8 @@ def random_quaternions(
Returns: Returns:
Quaternions as tensor of shape (N, 4). Quaternions as tensor of shape (N, 4).
""" """
if isinstance(device, str):
device = torch.device(device)
o = torch.randn((n, 4), dtype=dtype, device=device) o = torch.randn((n, 4), dtype=dtype, device=device)
s = (o * o).sum(1) s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
@ -313,7 +322,7 @@ def random_quaternions(
def random_rotations( def random_rotations(
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
): ) -> torch.Tensor:
""" """
Generate random rotations as 3x3 rotation matrices. Generate random rotations as 3x3 rotation matrices.
@ -332,7 +341,7 @@ def random_rotations(
def random_rotation( def random_rotation(
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
): ) -> torch.Tensor:
""" """
Generate a single random 3x3 rotation matrix. Generate a single random 3x3 rotation matrix.
@ -347,7 +356,7 @@ def random_rotation(
return random_rotations(1, dtype, device)[0] return random_rotations(1, dtype, device)[0]
def standardize_quaternion(quaternions): def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
""" """
Convert a unit quaternion to a standard form: one in which the real Convert a unit quaternion to a standard form: one in which the real
part is non negative. part is non negative.
@ -362,7 +371,7 @@ def standardize_quaternion(quaternions):
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
def quaternion_raw_multiply(a, b): def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
""" """
Multiply two quaternions. Multiply two quaternions.
Usual torch rules for broadcasting apply. Usual torch rules for broadcasting apply.
@ -383,7 +392,7 @@ def quaternion_raw_multiply(a, b):
return torch.stack((ow, ox, oy, oz), -1) return torch.stack((ow, ox, oy, oz), -1)
def quaternion_multiply(a, b): def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
""" """
Multiply two quaternions representing rotations, returning the quaternion Multiply two quaternions representing rotations, returning the quaternion
representing their composition, i.e. the versor with nonnegative real part. representing their composition, i.e. the versor with nonnegative real part.
@ -400,7 +409,7 @@ def quaternion_multiply(a, b):
return standardize_quaternion(ab) return standardize_quaternion(ab)
def quaternion_invert(quaternion): def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
""" """
Given a quaternion representing rotation, get the quaternion representing Given a quaternion representing rotation, get the quaternion representing
its inverse. its inverse.
@ -413,10 +422,11 @@ def quaternion_invert(quaternion):
The inverse, a tensor of quaternions of shape (..., 4). The inverse, a tensor of quaternions of shape (..., 4).
""" """
return quaternion * quaternion.new_tensor([1, -1, -1, -1]) scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
return quaternion * scaling
def quaternion_apply(quaternion, point): def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
""" """
Apply the rotation given by a quaternion to a 3D point. Apply the rotation given by a quaternion to a 3D point.
Usual torch rules for broadcasting apply. Usual torch rules for broadcasting apply.
@ -439,7 +449,7 @@ def quaternion_apply(quaternion, point):
return out[..., 1:] return out[..., 1:]
def axis_angle_to_matrix(axis_angle): def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
""" """
Convert rotations given as axis/angle to rotation matrices. Convert rotations given as axis/angle to rotation matrices.
@ -455,7 +465,7 @@ def axis_angle_to_matrix(axis_angle):
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
def matrix_to_axis_angle(matrix): def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
""" """
Convert rotations given as rotation matrices to axis/angle. Convert rotations given as rotation matrices to axis/angle.
@ -471,7 +481,7 @@ def matrix_to_axis_angle(matrix):
return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
def axis_angle_to_quaternion(axis_angle): def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
""" """
Convert rotations given as axis/angle to quaternions. Convert rotations given as axis/angle to quaternions.
@ -485,7 +495,7 @@ def axis_angle_to_quaternion(axis_angle):
quaternions with real part first, as tensor of shape (..., 4). quaternions with real part first, as tensor of shape (..., 4).
""" """
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
half_angles = 0.5 * angles half_angles = angles * 0.5
eps = 1e-6 eps = 1e-6
small_angles = angles.abs() < eps small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles) sin_half_angles_over_angles = torch.empty_like(angles)
@ -503,7 +513,7 @@ def axis_angle_to_quaternion(axis_angle):
return quaternions return quaternions
def quaternion_to_axis_angle(quaternions): def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
""" """
Convert rotations given as quaternions to axis/angle. Convert rotations given as quaternions to axis/angle.
@ -573,4 +583,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
IEEE Conference on Computer Vision and Pattern Recognition, 2019. IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035 Retrieved from http://arxiv.org/abs/1812.07035
""" """
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) batch_dim = matrix.size()[:-2]
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))

View File

@ -8,6 +8,7 @@
import itertools import itertools
import math import math
import unittest import unittest
from distutils.version import LooseVersion
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
@ -264,6 +265,25 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6 torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
) )
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
def test_scriptable(self):
torch.jit.script(axis_angle_to_matrix)
torch.jit.script(axis_angle_to_quaternion)
torch.jit.script(euler_angles_to_matrix)
torch.jit.script(matrix_to_axis_angle)
torch.jit.script(matrix_to_euler_angles)
torch.jit.script(matrix_to_quaternion)
torch.jit.script(matrix_to_rotation_6d)
torch.jit.script(quaternion_apply)
torch.jit.script(quaternion_multiply)
torch.jit.script(quaternion_to_matrix)
torch.jit.script(quaternion_to_axis_angle)
torch.jit.script(random_quaternions)
torch.jit.script(random_rotation)
torch.jit.script(random_rotations)
torch.jit.script(random_quaternions)
torch.jit.script(rotation_6d_to_matrix)
def _assert_quaternions_close( def _assert_quaternions_close(
self, self,
input: Union[torch.Tensor, np.ndarray], input: Union[torch.Tensor, np.ndarray],