Make some matrix conversion jittable (#898)

Summary:
Make sure the functions from `rotation_conversion` are jittable, and add some type hints.

Add tests to verify this is the case.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/898

Reviewed By: patricklabatut

Differential Revision: D31926103

Pulled By: bottler

fbshipit-source-id: bff6013c5ca2d452e37e631bd902f0674d5ca091
This commit is contained in:
una-dinosauria 2021-10-26 14:30:45 -07:00 committed by Facebook GitHub Bot
parent 29417d1f9b
commit bee31c48d3
2 changed files with 61 additions and 30 deletions

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Optional
import torch
@ -39,7 +38,7 @@ e.g.
"""
def quaternion_to_matrix(quaternions):
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to rotation matrices.
@ -70,7 +69,7 @@ def quaternion_to_matrix(quaternions):
return o.reshape(quaternions.shape[:-1] + (3, 3))
def _copysign(a, b):
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Return a tensor where each element has the absolute value taken from the,
corresponding element of a, with sign taken from the corresponding
@ -114,7 +113,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(*batch_dim, 9), dim=-1
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
@ -142,17 +141,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1)))
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
return quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
].reshape(*batch_dim, 4)
].reshape(batch_dim + (4,))
def _axis_angle_rotation(axis: str, angle):
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
"""
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
@ -172,15 +172,17 @@ def _axis_angle_rotation(axis: str, angle):
if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
elif axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
elif axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
else:
raise ValueError("letter must be either X, Y or Z.")
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str):
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
"""
Convert rotations given as Euler angles in radians to rotation matrices.
@ -201,13 +203,17 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
return functools.reduce(torch.matmul, matrices)
matrices = [
_axis_angle_rotation(c, e)
for c, e in zip(convention, torch.unbind(euler_angles, -1))
]
# return functools.reduce(torch.matmul, matrices)
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
):
) -> torch.Tensor:
"""
Extract the first or third Euler angle from the two members of
the matrix which are positive constant times its sine and cosine.
@ -238,16 +244,17 @@ def _angle_from_tan(
return torch.atan2(data[..., i2], -data[..., i1])
def _index_from_letter(letter: str):
def _index_from_letter(letter: str) -> int:
if letter == "X":
return 0
if letter == "Y":
return 1
if letter == "Z":
return 2
raise ValueError("letter must be either X, Y or Z.")
def matrix_to_euler_angles(matrix, convention: str):
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to Euler angles in radians.
@ -291,7 +298,7 @@ def matrix_to_euler_angles(matrix, convention: str):
def random_quaternions(
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
):
) -> torch.Tensor:
"""
Generate random quaternions representing rotations,
i.e. versors with nonnegative real part.
@ -305,6 +312,8 @@ def random_quaternions(
Returns:
Quaternions as tensor of shape (N, 4).
"""
if isinstance(device, str):
device = torch.device(device)
o = torch.randn((n, 4), dtype=dtype, device=device)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
@ -313,7 +322,7 @@ def random_quaternions(
def random_rotations(
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
):
) -> torch.Tensor:
"""
Generate random rotations as 3x3 rotation matrices.
@ -332,7 +341,7 @@ def random_rotations(
def random_rotation(
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
):
) -> torch.Tensor:
"""
Generate a single random 3x3 rotation matrix.
@ -347,7 +356,7 @@ def random_rotation(
return random_rotations(1, dtype, device)[0]
def standardize_quaternion(quaternions):
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
@ -362,7 +371,7 @@ def standardize_quaternion(quaternions):
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
def quaternion_raw_multiply(a, b):
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Multiply two quaternions.
Usual torch rules for broadcasting apply.
@ -383,7 +392,7 @@ def quaternion_raw_multiply(a, b):
return torch.stack((ow, ox, oy, oz), -1)
def quaternion_multiply(a, b):
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Multiply two quaternions representing rotations, returning the quaternion
representing their composition, i.e. the versor with nonnegative real part.
@ -400,7 +409,7 @@ def quaternion_multiply(a, b):
return standardize_quaternion(ab)
def quaternion_invert(quaternion):
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
"""
Given a quaternion representing rotation, get the quaternion representing
its inverse.
@ -413,10 +422,11 @@ def quaternion_invert(quaternion):
The inverse, a tensor of quaternions of shape (..., 4).
"""
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
return quaternion * scaling
def quaternion_apply(quaternion, point):
def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
"""
Apply the rotation given by a quaternion to a 3D point.
Usual torch rules for broadcasting apply.
@ -439,7 +449,7 @@ def quaternion_apply(quaternion, point):
return out[..., 1:]
def axis_angle_to_matrix(axis_angle):
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as axis/angle to rotation matrices.
@ -455,7 +465,7 @@ def axis_angle_to_matrix(axis_angle):
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
def matrix_to_axis_angle(matrix):
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to axis/angle.
@ -471,7 +481,7 @@ def matrix_to_axis_angle(matrix):
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
def axis_angle_to_quaternion(axis_angle):
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as axis/angle to quaternions.
@ -485,7 +495,7 @@ def axis_angle_to_quaternion(axis_angle):
quaternions with real part first, as tensor of shape (..., 4).
"""
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
half_angles = 0.5 * angles
half_angles = angles * 0.5
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
@ -503,7 +513,7 @@ def axis_angle_to_quaternion(axis_angle):
return quaternions
def quaternion_to_axis_angle(quaternions):
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to axis/angle.
@ -573,4 +583,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035
"""
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
batch_dim = matrix.size()[:-2]
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))

View File

@ -8,6 +8,7 @@
import itertools
import math
import unittest
from distutils.version import LooseVersion
from typing import Optional, Union
import numpy as np
@ -264,6 +265,25 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
)
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
def test_scriptable(self):
torch.jit.script(axis_angle_to_matrix)
torch.jit.script(axis_angle_to_quaternion)
torch.jit.script(euler_angles_to_matrix)
torch.jit.script(matrix_to_axis_angle)
torch.jit.script(matrix_to_euler_angles)
torch.jit.script(matrix_to_quaternion)
torch.jit.script(matrix_to_rotation_6d)
torch.jit.script(quaternion_apply)
torch.jit.script(quaternion_multiply)
torch.jit.script(quaternion_to_matrix)
torch.jit.script(quaternion_to_axis_angle)
torch.jit.script(random_quaternions)
torch.jit.script(random_rotation)
torch.jit.script(random_rotations)
torch.jit.script(random_quaternions)
torch.jit.script(rotation_6d_to_matrix)
def _assert_quaternions_close(
self,
input: Union[torch.Tensor, np.ndarray],