diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 98b22ca2..7d7aabfa 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import functools from typing import Optional import torch @@ -39,7 +38,7 @@ e.g. """ -def quaternion_to_matrix(quaternions): +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: """ Convert rotations given as quaternions to rotation matrices. @@ -70,7 +69,7 @@ def quaternion_to_matrix(quaternions): return o.reshape(quaternions.shape[:-1] + (3, 3)) -def _copysign(a, b): +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ Return a tensor where each element has the absolute value taken from the, corresponding element of a, with sign taken from the corresponding @@ -114,7 +113,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( - matrix.reshape(*batch_dim, 9), dim=-1 + matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( @@ -142,17 +141,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. - quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1))) + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) return quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] - ].reshape(*batch_dim, 4) + ].reshape(batch_dim + (4,)) -def _axis_angle_rotation(axis: str, angle): +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: """ Return the rotation matrices for one of the rotations about an axis of which Euler angles describe, for each value of the angle given. @@ -172,15 +172,17 @@ def _axis_angle_rotation(axis: str, angle): if axis == "X": R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) - if axis == "Y": + elif axis == "Y": R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) - if axis == "Z": + elif axis == "Z": R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) -def euler_angles_to_matrix(euler_angles, convention: str): +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: """ Convert rotations given as Euler angles in radians to rotation matrices. @@ -201,13 +203,17 @@ def euler_angles_to_matrix(euler_angles, convention: str): for letter in convention: if letter not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") - matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) - return functools.reduce(torch.matmul, matrices) + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) def _angle_from_tan( axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool -): +) -> torch.Tensor: """ Extract the first or third Euler angle from the two members of the matrix which are positive constant times its sine and cosine. @@ -238,16 +244,17 @@ def _angle_from_tan( return torch.atan2(data[..., i2], -data[..., i1]) -def _index_from_letter(letter: str): +def _index_from_letter(letter: str) -> int: if letter == "X": return 0 if letter == "Y": return 1 if letter == "Z": return 2 + raise ValueError("letter must be either X, Y or Z.") -def matrix_to_euler_angles(matrix, convention: str): +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: """ Convert rotations given as rotation matrices to Euler angles in radians. @@ -291,7 +298,7 @@ def matrix_to_euler_angles(matrix, convention: str): def random_quaternions( n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None -): +) -> torch.Tensor: """ Generate random quaternions representing rotations, i.e. versors with nonnegative real part. @@ -305,6 +312,8 @@ def random_quaternions( Returns: Quaternions as tensor of shape (N, 4). """ + if isinstance(device, str): + device = torch.device(device) o = torch.randn((n, 4), dtype=dtype, device=device) s = (o * o).sum(1) o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] @@ -313,7 +322,7 @@ def random_quaternions( def random_rotations( n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None -): +) -> torch.Tensor: """ Generate random rotations as 3x3 rotation matrices. @@ -332,7 +341,7 @@ def random_rotations( def random_rotation( dtype: Optional[torch.dtype] = None, device: Optional[Device] = None -): +) -> torch.Tensor: """ Generate a single random 3x3 rotation matrix. @@ -347,7 +356,7 @@ def random_rotation( return random_rotations(1, dtype, device)[0] -def standardize_quaternion(quaternions): +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: """ Convert a unit quaternion to a standard form: one in which the real part is non negative. @@ -362,7 +371,7 @@ def standardize_quaternion(quaternions): return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) -def quaternion_raw_multiply(a, b): +def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ Multiply two quaternions. Usual torch rules for broadcasting apply. @@ -383,7 +392,7 @@ def quaternion_raw_multiply(a, b): return torch.stack((ow, ox, oy, oz), -1) -def quaternion_multiply(a, b): +def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ Multiply two quaternions representing rotations, returning the quaternion representing their composition, i.e. the versor with nonnegative real part. @@ -400,7 +409,7 @@ def quaternion_multiply(a, b): return standardize_quaternion(ab) -def quaternion_invert(quaternion): +def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor: """ Given a quaternion representing rotation, get the quaternion representing its inverse. @@ -413,10 +422,11 @@ def quaternion_invert(quaternion): The inverse, a tensor of quaternions of shape (..., 4). """ - return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device) + return quaternion * scaling -def quaternion_apply(quaternion, point): +def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor: """ Apply the rotation given by a quaternion to a 3D point. Usual torch rules for broadcasting apply. @@ -439,7 +449,7 @@ def quaternion_apply(quaternion, point): return out[..., 1:] -def axis_angle_to_matrix(axis_angle): +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: """ Convert rotations given as axis/angle to rotation matrices. @@ -455,7 +465,7 @@ def axis_angle_to_matrix(axis_angle): return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) -def matrix_to_axis_angle(matrix): +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: """ Convert rotations given as rotation matrices to axis/angle. @@ -471,7 +481,7 @@ def matrix_to_axis_angle(matrix): return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) -def axis_angle_to_quaternion(axis_angle): +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: """ Convert rotations given as axis/angle to quaternions. @@ -485,7 +495,7 @@ def axis_angle_to_quaternion(axis_angle): quaternions with real part first, as tensor of shape (..., 4). """ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) - half_angles = 0.5 * angles + half_angles = angles * 0.5 eps = 1e-6 small_angles = angles.abs() < eps sin_half_angles_over_angles = torch.empty_like(angles) @@ -503,7 +513,7 @@ def axis_angle_to_quaternion(axis_angle): return quaternions -def quaternion_to_axis_angle(quaternions): +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: """ Convert rotations given as quaternions to axis/angle. @@ -573,4 +583,5 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: IEEE Conference on Computer Vision and Pattern Recognition, 2019. Retrieved from http://arxiv.org/abs/1812.07035 """ - return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py index 85888a45..263da99c 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -8,6 +8,7 @@ import itertools import math import unittest +from distutils.version import LooseVersion from typing import Optional, Union import numpy as np @@ -264,6 +265,25 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6 ) + @unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only") + def test_scriptable(self): + torch.jit.script(axis_angle_to_matrix) + torch.jit.script(axis_angle_to_quaternion) + torch.jit.script(euler_angles_to_matrix) + torch.jit.script(matrix_to_axis_angle) + torch.jit.script(matrix_to_euler_angles) + torch.jit.script(matrix_to_quaternion) + torch.jit.script(matrix_to_rotation_6d) + torch.jit.script(quaternion_apply) + torch.jit.script(quaternion_multiply) + torch.jit.script(quaternion_to_matrix) + torch.jit.script(quaternion_to_axis_angle) + torch.jit.script(random_quaternions) + torch.jit.script(random_rotation) + torch.jit.script(random_rotations) + torch.jit.script(random_quaternions) + torch.jit.script(rotation_6d_to_matrix) + def _assert_quaternions_close( self, input: Union[torch.Tensor, np.ndarray],