diff --git a/pytorch3d/transforms/__init__.py b/pytorch3d/transforms/__init__.py index 5709bd6a..efa0d631 100644 --- a/pytorch3d/transforms/__init__.py +++ b/pytorch3d/transforms/__init__.py @@ -22,6 +22,7 @@ from .rotation_conversions import ( ) from .so3 import ( so3_exponential_map, + so3_exp_map, so3_log_map, so3_relative_angle, so3_rotation_angle, diff --git a/pytorch3d/transforms/so3.py b/pytorch3d/transforms/so3.py index 8371d8a6..750a47b8 100644 --- a/pytorch3d/transforms/so3.py +++ b/pytorch3d/transforms/so3.py @@ -1,13 +1,20 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Tuple import torch +from ..transforms import acos_linear_extrapolation HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5 -def so3_relative_angle(R1, R2, cos_angle: bool = False): +def so3_relative_angle( + R1: torch.Tensor, + R2: torch.Tensor, + cos_angle: bool = False, + cos_bound: float = 1e-4, +) -> torch.Tensor: """ Calculates the relative angle (in radians) between pairs of rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))` @@ -20,8 +27,12 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False): R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`. R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`. cos_angle: If==True return cosine of the relative angle rather than - the angle itself. This can avoid the unstable - calculation of `acos`. + the angle itself. This can avoid the unstable calculation of `acos`. + cos_bound: Clamps the cosine of the relative rotation angle to + [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients + of the `acos` call. Note that the non-finite outputs/gradients + are returned when the angle is requested (i.e. `cos_angle==False`) + and the rotation angle is close to 0 or π. Returns: Corresponding rotation angles of shape `(minibatch,)`. @@ -32,10 +43,15 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False): ValueError if `R1` or `R2` has an unexpected trace. """ R12 = torch.bmm(R1, R2.permute(0, 2, 1)) - return so3_rotation_angle(R12, cos_angle=cos_angle) + return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound) -def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False): +def so3_rotation_angle( + R: torch.Tensor, + eps: float = 1e-4, + cos_angle: bool = False, + cos_bound: float = 1e-4, +) -> torch.Tensor: """ Calculates angles (in radians) of a batch of rotation matrices `R` with `angle = acos(0.5 * (Trace(R)-1))`. The trace of the @@ -47,8 +63,13 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False): R: Batch of rotation matrices of shape `(minibatch, 3, 3)`. eps: Tolerance for the valid trace check. cos_angle: If==True return cosine of the rotation angles rather than - the angle itself. This can avoid the unstable - calculation of `acos`. + the angle itself. This can avoid the unstable + calculation of `acos`. + cos_bound: Clamps the cosine of the rotation angle to + [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients + of the `acos` call. Note that the non-finite outputs/gradients + are returned when the angle is requested (i.e. `cos_angle==False`) + and the rotation angle is close to 0 or π. Returns: Corresponding rotation angles of shape `(minibatch,)`. @@ -68,20 +89,19 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False): if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any(): raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].") - # clamp to valid range - rot_trace = torch.clamp(rot_trace, -1.0, 3.0) - # phi ... rotation angle - phi = 0.5 * (rot_trace - 1.0) + phi_cos = (rot_trace - 1.0) * 0.5 if cos_angle: - return phi + return phi_cos else: - # pyre-fixme[16]: `float` has no attribute `acos`. - return phi.acos() + if cos_bound > 0.0: + return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound) + else: + return torch.acos(phi_cos) -def so3_exponential_map(log_rot, eps: float = 0.0001): +def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor: """ Convert a batch of logarithmic representations of rotation matrices `log_rot` to a batch of 3x3 rotation matrices using Rodrigues formula [1]. @@ -94,18 +114,31 @@ def so3_exponential_map(log_rot, eps: float = 0.0001): which is handled by clamping controlled with the `eps` argument. Args: - log_rot: Batch of vectors of shape `(minibatch , 3)`. + log_rot: Batch of vectors of shape `(minibatch, 3)`. eps: A float constant handling the conversion singularity. Returns: - Batch of rotation matrices of shape `(minibatch , 3 , 3)`. + Batch of rotation matrices of shape `(minibatch, 3, 3)`. Raises: ValueError if `log_rot` is of incorrect shape. [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula """ + return _so3_exp_map(log_rot, eps=eps)[0] + +so3_exponential_map = so3_exp_map + + +def _so3_exp_map( + log_rot: torch.Tensor, eps: float = 0.0001 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + A helper function that computes the so3 exponential map and, + apart from the rotation matrix, also returns intermediate variables + that can be re-used in other functions. + """ _, dim = log_rot.shape if dim != 3: raise ValueError("Input tensor shape has to be Nx3.") @@ -117,27 +150,35 @@ def so3_exponential_map(log_rot, eps: float = 0.0001): fac1 = rot_angles_inv * rot_angles.sin() fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos()) skews = hat(log_rot) + skews_square = torch.bmm(skews, skews) R = ( # pyre-fixme[16]: `float` has no attribute `__getitem__`. fac1[:, None, None] * skews - + fac2[:, None, None] * torch.bmm(skews, skews) + + fac2[:, None, None] * skews_square + torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None] ) - return R + return R, rot_angles, skews, skews_square -def so3_log_map(R, eps: float = 0.0001): +def so3_log_map( + R: torch.Tensor, eps: float = 0.0001, cos_bound: float = 1e-4 +) -> torch.Tensor: """ Convert a batch of 3x3 rotation matrices `R` to a batch of 3-dimensional matrix logarithms of rotation matrices The conversion has a singularity around `(R=I)` which is handled - by clamping controlled with the `eps` argument. + by clamping controlled with the `eps` and `cos_bound` arguments. Args: R: batch of rotation matrices of shape `(minibatch, 3, 3)`. eps: A float constant handling the conversion singularity. + cos_bound: Clamps the cosine of the rotation angle to + [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients + of the `acos` call when computing `so3_rotation_angle`. + Note that the non-finite outputs/gradients are returned when + the rotation angle is close to 0 or π. Returns: Batch of logarithms of input rotation matrices @@ -152,22 +193,26 @@ def so3_log_map(R, eps: float = 0.0001): if dim1 != 3 or dim2 != 3: raise ValueError("Input has to be a batch of 3x3 Tensors.") - phi = so3_rotation_angle(R) + phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps) - phi_sin = phi.sin() + phi_sin = torch.sin(phi) - phi_denom = ( - torch.clamp(phi_sin.abs(), eps) * phi_sin.sign() - + (phi_sin == 0).type_as(phi) * eps - ) + # We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin). + # Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with + # 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2 + phi_factor = torch.empty_like(phi) + ok_denom = phi_sin.abs() > (0.5 * eps) + phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12) + phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom]) + + log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1)) - log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1)) log_rot = hat_inv(log_rot_hat) return log_rot -def hat_inv(h): +def hat_inv(h: torch.Tensor) -> torch.Tensor: """ Compute the inverse Hat operator [1] of a batch of 3x3 matrices. @@ -188,9 +233,9 @@ def hat_inv(h): if dim1 != 3 or dim2 != 3: raise ValueError("Input has to be a batch of 3x3 Tensors.") - ss_diff = (h + h.permute(0, 2, 1)).abs().max() + ss_diff = torch.abs(h + h.permute(0, 2, 1)).max() if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL: - raise ValueError("One of input matrices not skew-symmetric.") + raise ValueError("One of input matrices is not skew-symmetric.") x = h[:, 2, 1] y = h[:, 0, 2] @@ -201,7 +246,7 @@ def hat_inv(h): return v -def hat(v): +def hat(v: torch.Tensor) -> torch.Tensor: """ Compute the Hat operator [1] of a batch of 3D vectors. @@ -225,7 +270,7 @@ def hat(v): if dim != 3: raise ValueError("Input vectors have to be 3-dimensional.") - h = v.new_zeros(N, 3, 3) + h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device) x, y, z = v.unbind(1) diff --git a/tests/test_so3.py b/tests/test_so3.py index b7958e4f..85ec908c 100644 --- a/tests/test_so3.py +++ b/tests/test_so3.py @@ -9,9 +9,10 @@ import torch from common_testing import TestCaseMixin from pytorch3d.transforms.so3 import ( hat, - so3_exponential_map, + so3_exp_map, so3_log_map, so3_relative_angle, + so3_rotation_angle, ) @@ -53,10 +54,10 @@ class TestSO3(TestCaseMixin, unittest.TestCase): def test_determinant(self): """ Tests whether the determinants of 3x3 rotation matrices produced - by `so3_exponential_map` are (almost) equal to 1. + by `so3_exp_map` are (almost) equal to 1. """ log_rot = TestSO3.init_log_rot(batch_size=30) - Rs = so3_exponential_map(log_rot) + Rs = so3_exp_map(log_rot) dets = torch.det(Rs) self.assertClose(dets, torch.ones_like(dets), atol=1e-4) @@ -75,14 +76,14 @@ class TestSO3(TestCaseMixin, unittest.TestCase): def test_bad_so3_input_value_err(self): """ - Tests whether `so3_exponential_map` and `so3_log_map` correctly return + Tests whether `so3_exp_map` and `so3_log_map` correctly return a ValueError if called with an argument of incorrect shape or, in case - of `so3_exponential_map`, unexpected trace. + of `so3_exp_map`, unexpected trace. """ device = torch.device("cuda:0") log_rot = torch.randn(size=[5, 4], device=device) with self.assertRaises(ValueError) as err: - so3_exponential_map(log_rot) + so3_exp_map(log_rot) self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception)) rot = torch.randn(size=[5, 3, 5], device=device) @@ -106,17 +107,22 @@ class TestSO3(TestCaseMixin, unittest.TestCase): def test_so3_exp_singularity(self, batch_size: int = 100): """ - Tests whether the `so3_exponential_map` is robust to the input vectors + Tests whether the `so3_exp_map` is robust to the input vectors the norms of which are close to the numerically unstable region (vectors with low l2-norms). """ # generate random log-rotations with a tiny angle log_rot = TestSO3.init_log_rot(batch_size=batch_size) log_rot_small = log_rot * 1e-6 - R = so3_exponential_map(log_rot_small) + log_rot_small.requires_grad = True + R = so3_exp_map(log_rot_small) # tests whether all outputs are finite - R_sum = float(R.sum()) - self.assertEqual(R_sum, R_sum) + self.assertTrue(torch.isfinite(R).all()) + # tests whether the gradient is not None and all finite + loss = R.sum() + loss.backward() + self.assertIsNotNone(log_rot_small.grad) + self.assertTrue(torch.isfinite(log_rot_small.grad).all()) def test_so3_log_singularity(self, batch_size: int = 100): """ @@ -129,6 +135,107 @@ class TestSO3(TestCaseMixin, unittest.TestCase): identity = torch.eye(3, device=device) rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device) r = [identity, rot180] + # add random rotations and random almost orthonormal matrices + r.extend( + [ + torch.qr(identity + torch.randn_like(identity) * 1e-4)[0] + + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3 + # this adds random noise to the second half + # of the random orthogonal matrices to generate + # near-orthogonal matrices + for i in range(batch_size - 2) + ] + ) + r = torch.stack(r) + r.requires_grad = True + # the log of the rotation matrix r + r_log = so3_log_map(r, cos_bound=1e-4, eps=1e-2) + # tests whether all outputs are finite + self.assertTrue(torch.isfinite(r_log).all()) + # tests whether the gradient is not None and all finite + loss = r.sum() + loss.backward() + self.assertIsNotNone(r.grad) + self.assertTrue(torch.isfinite(r.grad).all()) + + def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100): + """ + Check that + `so3_exp_map(so3_log_map(so3_exp_map(log_rot))) + == so3_exp_map(log_rot)` + for a randomly generated batch of rotation matrix logarithms `log_rot`. + Unlike `test_so3_log_to_exp_to_log`, this test checks the + correctness of converting a `log_rot` which contains values > math.pi. + """ + log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size) + # check also the singular cases where rot. angle = {0, 2pi} + log_rot[:2] = 0 + log_rot[1, 0] = 2.0 * math.pi - 1e-6 + rot = so3_exp_map(log_rot, eps=1e-4) + rot_ = so3_exp_map(so3_log_map(rot, eps=1e-4, cos_bound=1e-6), eps=1e-6) + self.assertClose(rot, rot_, atol=0.01) + angles = so3_relative_angle(rot, rot_, cos_bound=1e-6) + self.assertClose(angles, torch.zeros_like(angles), atol=0.01) + + def test_so3_log_to_exp_to_log(self, batch_size: int = 100): + """ + Check that `so3_log_map(so3_exp_map(log_rot))==log_rot` for + a randomly generated batch of rotation matrix logarithms `log_rot`. + """ + log_rot = TestSO3.init_log_rot(batch_size=batch_size) + # check also the singular cases where rot. angle = 0 + log_rot[:1] = 0 + log_rot_ = so3_log_map(so3_exp_map(log_rot)) + self.assertClose(log_rot, log_rot_, atol=1e-4) + + def test_so3_exp_to_log_to_exp(self, batch_size: int = 100): + """ + Check that `so3_exp_map(so3_log_map(R))==R` for + a batch of randomly generated rotation matrices `R`. + """ + rot = TestSO3.init_rot(batch_size=batch_size) + non_singular = (so3_rotation_angle(rot) - math.pi).abs() > 1e-2 + rot = rot[non_singular] + rot_ = so3_exp_map(so3_log_map(rot, eps=1e-8, cos_bound=1e-8), eps=1e-8) + self.assertClose(rot_, rot, atol=0.1) + angles = so3_relative_angle(rot, rot_, cos_bound=1e-4) + self.assertClose(angles, torch.zeros_like(angles), atol=0.1) + + def test_so3_cos_relative_angle(self, batch_size: int = 100): + """ + Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()` + is the same as `so3_relative_angle(R1, R2, cos_angle=True)` for + batches of randomly generated rotation matrices `R1` and `R2`. + """ + rot1 = TestSO3.init_rot(batch_size=batch_size) + rot2 = TestSO3.init_rot(batch_size=batch_size) + angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos() + angles_ = so3_relative_angle(rot1, rot2, cos_angle=True) + self.assertClose(angles, angles_, atol=1e-4) + + def test_so3_cos_angle(self, batch_size: int = 100): + """ + Check that `so3_rotation_angle(R, cos_angle=False).cos()` + is the same as `so3_rotation_angle(R, cos_angle=True)` for + a batch of randomly generated rotation matrices `R`. + """ + rot = TestSO3.init_rot(batch_size=batch_size) + angles = so3_rotation_angle(rot, cos_angle=False).cos() + angles_ = so3_rotation_angle(rot, cos_angle=True) + self.assertClose(angles, angles_, atol=1e-4) + + def test_so3_cos_bound(self, batch_size: int = 100): + """ + Checks that for an identity rotation `R=I`, the so3_rotation_angle returns + non-finite gradients when `cos_bound=None` and finite gradients + for `cos_bound > 0.0`. + """ + # generate random rotations with a tiny angle to generate cases + # with the gradient singularity + device = torch.device("cuda:0") + identity = torch.eye(3, device=device) + rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device) + r = [identity, rot180] r.extend( [ torch.qr(identity + torch.randn_like(identity) * 1e-4)[0] @@ -136,65 +243,25 @@ class TestSO3(TestCaseMixin, unittest.TestCase): ] ) r = torch.stack(r) - # the log of the rotation matrix r - r_log = so3_log_map(r) - # tests whether all outputs are finite - r_sum = float(r_log.sum()) - self.assertEqual(r_sum, r_sum) - - def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100): - """ - Check that - `so3_exponential_map(so3_log_map(so3_exponential_map(log_rot))) - == so3_exponential_map(log_rot)` - for a randomly generated batch of rotation matrix logarithms `log_rot`. - Unlike `test_so3_log_to_exp_to_log`, this test checks the - correctness of converting a `log_rot` which contains values > math.pi. - """ - log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size) - # check also the singular cases where rot. angle = {0, pi, 2pi, 3pi} - log_rot[:3] = 0 - log_rot[1, 0] = math.pi - log_rot[2, 0] = 2.0 * math.pi - log_rot[3, 0] = 3.0 * math.pi - rot = so3_exponential_map(log_rot, eps=1e-8) - rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8) - angles = so3_relative_angle(rot, rot_) - self.assertClose(angles, torch.zeros_like(angles), atol=0.01) - - def test_so3_log_to_exp_to_log(self, batch_size: int = 100): - """ - Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for - a randomly generated batch of rotation matrix logarithms `log_rot`. - """ - log_rot = TestSO3.init_log_rot(batch_size=batch_size) - # check also the singular cases where rot. angle = 0 - log_rot[:1] = 0 - log_rot_ = so3_log_map(so3_exponential_map(log_rot)) - self.assertClose(log_rot, log_rot_, atol=1e-4) - - def test_so3_exp_to_log_to_exp(self, batch_size: int = 100): - """ - Check that `so3_exponential_map(so3_log_map(R))==R` for - a batch of randomly generated rotation matrices `R`. - """ - rot = TestSO3.init_rot(batch_size=batch_size) - rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8) - angles = so3_relative_angle(rot, rot_) - # TODO: a lot of precision lost here ... - self.assertClose(angles, torch.zeros_like(angles), atol=0.1) - - def test_so3_cos_angle(self, batch_size: int = 100): - """ - Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()` - is the same as `so3_relative_angle(R1, R2, cos_angle=True)` - batches of randomly generated rotation matrices `R1` and `R2`. - """ - rot1 = TestSO3.init_rot(batch_size=batch_size) - rot2 = TestSO3.init_rot(batch_size=batch_size) - angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos() - angles_ = so3_relative_angle(rot1, rot2, cos_angle=True) - self.assertClose(angles, angles_) + r.requires_grad = True + for is_grad_finite in (True, False): + # clear the gradients and decide the cos_bound: + # for is_grad_finite we run so3_rotation_angle with cos_bound + # set to a small float, otherwise we set to 0.0 + r.grad = None + cos_bound = 1e-4 if is_grad_finite else 0.0 + # compute the angles of r + angles = so3_rotation_angle(r, cos_bound=cos_bound) + # tests whether all outputs are finite in both cases + self.assertTrue(torch.isfinite(angles).all()) + # compute the gradients + loss = angles.sum() + loss.backward() + # tests whether the gradient is not None for both cases + self.assertIsNotNone(r.grad) + if is_grad_finite: + # all grad values have to be finite + self.assertTrue(torch.isfinite(r.grad).all()) @staticmethod def so3_expmap(batch_size: int = 10): @@ -202,7 +269,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase): torch.cuda.synchronize() def compute_rots(): - so3_exponential_map(log_rot) + so3_exp_map(log_rot) torch.cuda.synchronize() return compute_rots