mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Make some matrix conversion jittable (#898)
Summary: Make sure the functions from `rotation_conversion` are jittable, and add some type hints. Add tests to verify this is the case. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/898 Reviewed By: patricklabatut Differential Revision: D31926103 Pulled By: bottler fbshipit-source-id: bff6013c5ca2d452e37e631bd902f0674d5ca091
This commit is contained in:
parent
29417d1f9b
commit
bee31c48d3
@ -4,7 +4,6 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import functools
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -39,7 +38,7 @@ e.g.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def quaternion_to_matrix(quaternions):
|
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as quaternions to rotation matrices.
|
Convert rotations given as quaternions to rotation matrices.
|
||||||
|
|
||||||
@ -70,7 +69,7 @@ def quaternion_to_matrix(quaternions):
|
|||||||
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
||||||
|
|
||||||
|
|
||||||
def _copysign(a, b):
|
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return a tensor where each element has the absolute value taken from the,
|
Return a tensor where each element has the absolute value taken from the,
|
||||||
corresponding element of a, with sign taken from the corresponding
|
corresponding element of a, with sign taken from the corresponding
|
||||||
@ -114,7 +113,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
batch_dim = matrix.shape[:-2]
|
batch_dim = matrix.shape[:-2]
|
||||||
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||||
matrix.reshape(*batch_dim, 9), dim=-1
|
matrix.reshape(batch_dim + (9,)), dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
q_abs = _sqrt_positive_part(
|
q_abs = _sqrt_positive_part(
|
||||||
@ -142,17 +141,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
||||||
# the candidate won't be picked.
|
# the candidate won't be picked.
|
||||||
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1)))
|
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
||||||
|
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
||||||
|
|
||||||
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
||||||
# forall i; we pick the best-conditioned one (with the largest denominator)
|
# forall i; we pick the best-conditioned one (with the largest denominator)
|
||||||
|
|
||||||
return quat_candidates[
|
return quat_candidates[
|
||||||
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
|
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
|
||||||
].reshape(*batch_dim, 4)
|
].reshape(batch_dim + (4,))
|
||||||
|
|
||||||
|
|
||||||
def _axis_angle_rotation(axis: str, angle):
|
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return the rotation matrices for one of the rotations about an axis
|
Return the rotation matrices for one of the rotations about an axis
|
||||||
of which Euler angles describe, for each value of the angle given.
|
of which Euler angles describe, for each value of the angle given.
|
||||||
@ -172,15 +172,17 @@ def _axis_angle_rotation(axis: str, angle):
|
|||||||
|
|
||||||
if axis == "X":
|
if axis == "X":
|
||||||
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
||||||
if axis == "Y":
|
elif axis == "Y":
|
||||||
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
||||||
if axis == "Z":
|
elif axis == "Z":
|
||||||
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
||||||
|
else:
|
||||||
|
raise ValueError("letter must be either X, Y or Z.")
|
||||||
|
|
||||||
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
||||||
|
|
||||||
|
|
||||||
def euler_angles_to_matrix(euler_angles, convention: str):
|
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as Euler angles in radians to rotation matrices.
|
Convert rotations given as Euler angles in radians to rotation matrices.
|
||||||
|
|
||||||
@ -201,13 +203,17 @@ def euler_angles_to_matrix(euler_angles, convention: str):
|
|||||||
for letter in convention:
|
for letter in convention:
|
||||||
if letter not in ("X", "Y", "Z"):
|
if letter not in ("X", "Y", "Z"):
|
||||||
raise ValueError(f"Invalid letter {letter} in convention string.")
|
raise ValueError(f"Invalid letter {letter} in convention string.")
|
||||||
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
matrices = [
|
||||||
return functools.reduce(torch.matmul, matrices)
|
_axis_angle_rotation(c, e)
|
||||||
|
for c, e in zip(convention, torch.unbind(euler_angles, -1))
|
||||||
|
]
|
||||||
|
# return functools.reduce(torch.matmul, matrices)
|
||||||
|
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
|
||||||
|
|
||||||
|
|
||||||
def _angle_from_tan(
|
def _angle_from_tan(
|
||||||
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
||||||
):
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Extract the first or third Euler angle from the two members of
|
Extract the first or third Euler angle from the two members of
|
||||||
the matrix which are positive constant times its sine and cosine.
|
the matrix which are positive constant times its sine and cosine.
|
||||||
@ -238,16 +244,17 @@ def _angle_from_tan(
|
|||||||
return torch.atan2(data[..., i2], -data[..., i1])
|
return torch.atan2(data[..., i2], -data[..., i1])
|
||||||
|
|
||||||
|
|
||||||
def _index_from_letter(letter: str):
|
def _index_from_letter(letter: str) -> int:
|
||||||
if letter == "X":
|
if letter == "X":
|
||||||
return 0
|
return 0
|
||||||
if letter == "Y":
|
if letter == "Y":
|
||||||
return 1
|
return 1
|
||||||
if letter == "Z":
|
if letter == "Z":
|
||||||
return 2
|
return 2
|
||||||
|
raise ValueError("letter must be either X, Y or Z.")
|
||||||
|
|
||||||
|
|
||||||
def matrix_to_euler_angles(matrix, convention: str):
|
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as rotation matrices to Euler angles in radians.
|
Convert rotations given as rotation matrices to Euler angles in radians.
|
||||||
|
|
||||||
@ -291,7 +298,7 @@ def matrix_to_euler_angles(matrix, convention: str):
|
|||||||
|
|
||||||
def random_quaternions(
|
def random_quaternions(
|
||||||
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||||
):
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Generate random quaternions representing rotations,
|
Generate random quaternions representing rotations,
|
||||||
i.e. versors with nonnegative real part.
|
i.e. versors with nonnegative real part.
|
||||||
@ -305,6 +312,8 @@ def random_quaternions(
|
|||||||
Returns:
|
Returns:
|
||||||
Quaternions as tensor of shape (N, 4).
|
Quaternions as tensor of shape (N, 4).
|
||||||
"""
|
"""
|
||||||
|
if isinstance(device, str):
|
||||||
|
device = torch.device(device)
|
||||||
o = torch.randn((n, 4), dtype=dtype, device=device)
|
o = torch.randn((n, 4), dtype=dtype, device=device)
|
||||||
s = (o * o).sum(1)
|
s = (o * o).sum(1)
|
||||||
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
||||||
@ -313,7 +322,7 @@ def random_quaternions(
|
|||||||
|
|
||||||
def random_rotations(
|
def random_rotations(
|
||||||
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||||
):
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Generate random rotations as 3x3 rotation matrices.
|
Generate random rotations as 3x3 rotation matrices.
|
||||||
|
|
||||||
@ -332,7 +341,7 @@ def random_rotations(
|
|||||||
|
|
||||||
def random_rotation(
|
def random_rotation(
|
||||||
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||||
):
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Generate a single random 3x3 rotation matrix.
|
Generate a single random 3x3 rotation matrix.
|
||||||
|
|
||||||
@ -347,7 +356,7 @@ def random_rotation(
|
|||||||
return random_rotations(1, dtype, device)[0]
|
return random_rotations(1, dtype, device)[0]
|
||||||
|
|
||||||
|
|
||||||
def standardize_quaternion(quaternions):
|
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert a unit quaternion to a standard form: one in which the real
|
Convert a unit quaternion to a standard form: one in which the real
|
||||||
part is non negative.
|
part is non negative.
|
||||||
@ -362,7 +371,7 @@ def standardize_quaternion(quaternions):
|
|||||||
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
||||||
|
|
||||||
|
|
||||||
def quaternion_raw_multiply(a, b):
|
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Multiply two quaternions.
|
Multiply two quaternions.
|
||||||
Usual torch rules for broadcasting apply.
|
Usual torch rules for broadcasting apply.
|
||||||
@ -383,7 +392,7 @@ def quaternion_raw_multiply(a, b):
|
|||||||
return torch.stack((ow, ox, oy, oz), -1)
|
return torch.stack((ow, ox, oy, oz), -1)
|
||||||
|
|
||||||
|
|
||||||
def quaternion_multiply(a, b):
|
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Multiply two quaternions representing rotations, returning the quaternion
|
Multiply two quaternions representing rotations, returning the quaternion
|
||||||
representing their composition, i.e. the versor with nonnegative real part.
|
representing their composition, i.e. the versor with nonnegative real part.
|
||||||
@ -400,7 +409,7 @@ def quaternion_multiply(a, b):
|
|||||||
return standardize_quaternion(ab)
|
return standardize_quaternion(ab)
|
||||||
|
|
||||||
|
|
||||||
def quaternion_invert(quaternion):
|
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Given a quaternion representing rotation, get the quaternion representing
|
Given a quaternion representing rotation, get the quaternion representing
|
||||||
its inverse.
|
its inverse.
|
||||||
@ -413,10 +422,11 @@ def quaternion_invert(quaternion):
|
|||||||
The inverse, a tensor of quaternions of shape (..., 4).
|
The inverse, a tensor of quaternions of shape (..., 4).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
|
scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
|
||||||
|
return quaternion * scaling
|
||||||
|
|
||||||
|
|
||||||
def quaternion_apply(quaternion, point):
|
def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Apply the rotation given by a quaternion to a 3D point.
|
Apply the rotation given by a quaternion to a 3D point.
|
||||||
Usual torch rules for broadcasting apply.
|
Usual torch rules for broadcasting apply.
|
||||||
@ -439,7 +449,7 @@ def quaternion_apply(quaternion, point):
|
|||||||
return out[..., 1:]
|
return out[..., 1:]
|
||||||
|
|
||||||
|
|
||||||
def axis_angle_to_matrix(axis_angle):
|
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as axis/angle to rotation matrices.
|
Convert rotations given as axis/angle to rotation matrices.
|
||||||
|
|
||||||
@ -455,7 +465,7 @@ def axis_angle_to_matrix(axis_angle):
|
|||||||
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||||
|
|
||||||
|
|
||||||
def matrix_to_axis_angle(matrix):
|
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as rotation matrices to axis/angle.
|
Convert rotations given as rotation matrices to axis/angle.
|
||||||
|
|
||||||
@ -471,7 +481,7 @@ def matrix_to_axis_angle(matrix):
|
|||||||
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||||
|
|
||||||
|
|
||||||
def axis_angle_to_quaternion(axis_angle):
|
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as axis/angle to quaternions.
|
Convert rotations given as axis/angle to quaternions.
|
||||||
|
|
||||||
@ -485,7 +495,7 @@ def axis_angle_to_quaternion(axis_angle):
|
|||||||
quaternions with real part first, as tensor of shape (..., 4).
|
quaternions with real part first, as tensor of shape (..., 4).
|
||||||
"""
|
"""
|
||||||
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
||||||
half_angles = 0.5 * angles
|
half_angles = angles * 0.5
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
small_angles = angles.abs() < eps
|
small_angles = angles.abs() < eps
|
||||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||||
@ -503,7 +513,7 @@ def axis_angle_to_quaternion(axis_angle):
|
|||||||
return quaternions
|
return quaternions
|
||||||
|
|
||||||
|
|
||||||
def quaternion_to_axis_angle(quaternions):
|
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as quaternions to axis/angle.
|
Convert rotations given as quaternions to axis/angle.
|
||||||
|
|
||||||
@ -573,4 +583,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
|||||||
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
||||||
Retrieved from http://arxiv.org/abs/1812.07035
|
Retrieved from http://arxiv.org/abs/1812.07035
|
||||||
"""
|
"""
|
||||||
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
|
batch_dim = matrix.size()[:-2]
|
||||||
|
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
from distutils.version import LooseVersion
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -264,6 +265,25 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
|
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
|
||||||
|
def test_scriptable(self):
|
||||||
|
torch.jit.script(axis_angle_to_matrix)
|
||||||
|
torch.jit.script(axis_angle_to_quaternion)
|
||||||
|
torch.jit.script(euler_angles_to_matrix)
|
||||||
|
torch.jit.script(matrix_to_axis_angle)
|
||||||
|
torch.jit.script(matrix_to_euler_angles)
|
||||||
|
torch.jit.script(matrix_to_quaternion)
|
||||||
|
torch.jit.script(matrix_to_rotation_6d)
|
||||||
|
torch.jit.script(quaternion_apply)
|
||||||
|
torch.jit.script(quaternion_multiply)
|
||||||
|
torch.jit.script(quaternion_to_matrix)
|
||||||
|
torch.jit.script(quaternion_to_axis_angle)
|
||||||
|
torch.jit.script(random_quaternions)
|
||||||
|
torch.jit.script(random_rotation)
|
||||||
|
torch.jit.script(random_rotations)
|
||||||
|
torch.jit.script(random_quaternions)
|
||||||
|
torch.jit.script(rotation_6d_to_matrix)
|
||||||
|
|
||||||
def _assert_quaternions_close(
|
def _assert_quaternions_close(
|
||||||
self,
|
self,
|
||||||
input: Union[torch.Tensor, np.ndarray],
|
input: Union[torch.Tensor, np.ndarray],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user