SO3 improvements for stable gradients.

Summary:
Improves so3 functions to make gradient computation stable:
- Instead of `torch.acos`, uses `acos_linear_extrapolation` which has finite gradients of reasonable magnitude for all inputs.
- Adds tests for the latter.

The tests of the finiteness of the gradient in `test_so3_exp_singularity`, `test_so3_exp_singularity`, `test_so3_cos_bound` would fail if the `so3` functions would call `torch.acos` instead of `acos_linear_extrapolation`.

Reviewed By: bottler

Differential Revision: D23326429

fbshipit-source-id: dc296abf2ae3ddfb3942c8146621491a9cb740ee
This commit is contained in:
David Novotny 2021-06-21 04:47:31 -07:00 committed by Facebook GitHub Bot
parent dd45123f20
commit 9f14e82b5a
3 changed files with 216 additions and 103 deletions

View File

@ -22,6 +22,7 @@ from .rotation_conversions import (
) )
from .so3 import ( from .so3 import (
so3_exponential_map, so3_exponential_map,
so3_exp_map,
so3_log_map, so3_log_map,
so3_relative_angle, so3_relative_angle,
so3_rotation_angle, so3_rotation_angle,

View File

@ -1,13 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple
import torch import torch
from ..transforms import acos_linear_extrapolation
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5 HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
def so3_relative_angle(R1, R2, cos_angle: bool = False): def so3_relative_angle(
R1: torch.Tensor,
R2: torch.Tensor,
cos_angle: bool = False,
cos_bound: float = 1e-4,
) -> torch.Tensor:
""" """
Calculates the relative angle (in radians) between pairs of Calculates the relative angle (in radians) between pairs of
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))` rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
@ -20,8 +27,12 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`. R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`. R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
cos_angle: If==True return cosine of the relative angle rather than cos_angle: If==True return cosine of the relative angle rather than
the angle itself. This can avoid the unstable the angle itself. This can avoid the unstable calculation of `acos`.
calculation of `acos`. cos_bound: Clamps the cosine of the relative rotation angle to
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call. Note that the non-finite outputs/gradients
are returned when the angle is requested (i.e. `cos_angle==False`)
and the rotation angle is close to 0 or π.
Returns: Returns:
Corresponding rotation angles of shape `(minibatch,)`. Corresponding rotation angles of shape `(minibatch,)`.
@ -32,10 +43,15 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
ValueError if `R1` or `R2` has an unexpected trace. ValueError if `R1` or `R2` has an unexpected trace.
""" """
R12 = torch.bmm(R1, R2.permute(0, 2, 1)) R12 = torch.bmm(R1, R2.permute(0, 2, 1))
return so3_rotation_angle(R12, cos_angle=cos_angle) return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False): def so3_rotation_angle(
R: torch.Tensor,
eps: float = 1e-4,
cos_angle: bool = False,
cos_bound: float = 1e-4,
) -> torch.Tensor:
""" """
Calculates angles (in radians) of a batch of rotation matrices `R` with Calculates angles (in radians) of a batch of rotation matrices `R` with
`angle = acos(0.5 * (Trace(R)-1))`. The trace of the `angle = acos(0.5 * (Trace(R)-1))`. The trace of the
@ -47,8 +63,13 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
R: Batch of rotation matrices of shape `(minibatch, 3, 3)`. R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: Tolerance for the valid trace check. eps: Tolerance for the valid trace check.
cos_angle: If==True return cosine of the rotation angles rather than cos_angle: If==True return cosine of the rotation angles rather than
the angle itself. This can avoid the unstable the angle itself. This can avoid the unstable
calculation of `acos`. calculation of `acos`.
cos_bound: Clamps the cosine of the rotation angle to
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call. Note that the non-finite outputs/gradients
are returned when the angle is requested (i.e. `cos_angle==False`)
and the rotation angle is close to 0 or π.
Returns: Returns:
Corresponding rotation angles of shape `(minibatch,)`. Corresponding rotation angles of shape `(minibatch,)`.
@ -68,20 +89,19 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any(): if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].") raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
# clamp to valid range
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)
# phi ... rotation angle # phi ... rotation angle
phi = 0.5 * (rot_trace - 1.0) phi_cos = (rot_trace - 1.0) * 0.5
if cos_angle: if cos_angle:
return phi return phi_cos
else: else:
# pyre-fixme[16]: `float` has no attribute `acos`. if cos_bound > 0.0:
return phi.acos() return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
else:
return torch.acos(phi_cos)
def so3_exponential_map(log_rot, eps: float = 0.0001): def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
""" """
Convert a batch of logarithmic representations of rotation matrices `log_rot` Convert a batch of logarithmic representations of rotation matrices `log_rot`
to a batch of 3x3 rotation matrices using Rodrigues formula [1]. to a batch of 3x3 rotation matrices using Rodrigues formula [1].
@ -94,18 +114,31 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
which is handled by clamping controlled with the `eps` argument. which is handled by clamping controlled with the `eps` argument.
Args: Args:
log_rot: Batch of vectors of shape `(minibatch , 3)`. log_rot: Batch of vectors of shape `(minibatch, 3)`.
eps: A float constant handling the conversion singularity. eps: A float constant handling the conversion singularity.
Returns: Returns:
Batch of rotation matrices of shape `(minibatch , 3 , 3)`. Batch of rotation matrices of shape `(minibatch, 3, 3)`.
Raises: Raises:
ValueError if `log_rot` is of incorrect shape. ValueError if `log_rot` is of incorrect shape.
[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
""" """
return _so3_exp_map(log_rot, eps=eps)[0]
so3_exponential_map = so3_exp_map
def _so3_exp_map(
log_rot: torch.Tensor, eps: float = 0.0001
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
A helper function that computes the so3 exponential map and,
apart from the rotation matrix, also returns intermediate variables
that can be re-used in other functions.
"""
_, dim = log_rot.shape _, dim = log_rot.shape
if dim != 3: if dim != 3:
raise ValueError("Input tensor shape has to be Nx3.") raise ValueError("Input tensor shape has to be Nx3.")
@ -117,27 +150,35 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
fac1 = rot_angles_inv * rot_angles.sin() fac1 = rot_angles_inv * rot_angles.sin()
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos()) fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
skews = hat(log_rot) skews = hat(log_rot)
skews_square = torch.bmm(skews, skews)
R = ( R = (
# pyre-fixme[16]: `float` has no attribute `__getitem__`. # pyre-fixme[16]: `float` has no attribute `__getitem__`.
fac1[:, None, None] * skews fac1[:, None, None] * skews
+ fac2[:, None, None] * torch.bmm(skews, skews) + fac2[:, None, None] * skews_square
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None] + torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
) )
return R return R, rot_angles, skews, skews_square
def so3_log_map(R, eps: float = 0.0001): def so3_log_map(
R: torch.Tensor, eps: float = 0.0001, cos_bound: float = 1e-4
) -> torch.Tensor:
""" """
Convert a batch of 3x3 rotation matrices `R` Convert a batch of 3x3 rotation matrices `R`
to a batch of 3-dimensional matrix logarithms of rotation matrices to a batch of 3-dimensional matrix logarithms of rotation matrices
The conversion has a singularity around `(R=I)` which is handled The conversion has a singularity around `(R=I)` which is handled
by clamping controlled with the `eps` argument. by clamping controlled with the `eps` and `cos_bound` arguments.
Args: Args:
R: batch of rotation matrices of shape `(minibatch, 3, 3)`. R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: A float constant handling the conversion singularity. eps: A float constant handling the conversion singularity.
cos_bound: Clamps the cosine of the rotation angle to
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call when computing `so3_rotation_angle`.
Note that the non-finite outputs/gradients are returned when
the rotation angle is close to 0 or π.
Returns: Returns:
Batch of logarithms of input rotation matrices Batch of logarithms of input rotation matrices
@ -152,22 +193,26 @@ def so3_log_map(R, eps: float = 0.0001):
if dim1 != 3 or dim2 != 3: if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.") raise ValueError("Input has to be a batch of 3x3 Tensors.")
phi = so3_rotation_angle(R) phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)
phi_sin = phi.sin() phi_sin = torch.sin(phi)
phi_denom = ( # We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
torch.clamp(phi_sin.abs(), eps) * phi_sin.sign() # Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
+ (phi_sin == 0).type_as(phi) * eps # 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
) phi_factor = torch.empty_like(phi)
ok_denom = phi_sin.abs() > (0.5 * eps)
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1))
log_rot = hat_inv(log_rot_hat) log_rot = hat_inv(log_rot_hat)
return log_rot return log_rot
def hat_inv(h): def hat_inv(h: torch.Tensor) -> torch.Tensor:
""" """
Compute the inverse Hat operator [1] of a batch of 3x3 matrices. Compute the inverse Hat operator [1] of a batch of 3x3 matrices.
@ -188,9 +233,9 @@ def hat_inv(h):
if dim1 != 3 or dim2 != 3: if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.") raise ValueError("Input has to be a batch of 3x3 Tensors.")
ss_diff = (h + h.permute(0, 2, 1)).abs().max() ss_diff = torch.abs(h + h.permute(0, 2, 1)).max()
if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL: if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
raise ValueError("One of input matrices not skew-symmetric.") raise ValueError("One of input matrices is not skew-symmetric.")
x = h[:, 2, 1] x = h[:, 2, 1]
y = h[:, 0, 2] y = h[:, 0, 2]
@ -201,7 +246,7 @@ def hat_inv(h):
return v return v
def hat(v): def hat(v: torch.Tensor) -> torch.Tensor:
""" """
Compute the Hat operator [1] of a batch of 3D vectors. Compute the Hat operator [1] of a batch of 3D vectors.
@ -225,7 +270,7 @@ def hat(v):
if dim != 3: if dim != 3:
raise ValueError("Input vectors have to be 3-dimensional.") raise ValueError("Input vectors have to be 3-dimensional.")
h = v.new_zeros(N, 3, 3) h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device)
x, y, z = v.unbind(1) x, y, z = v.unbind(1)

View File

@ -9,9 +9,10 @@ import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.transforms.so3 import ( from pytorch3d.transforms.so3 import (
hat, hat,
so3_exponential_map, so3_exp_map,
so3_log_map, so3_log_map,
so3_relative_angle, so3_relative_angle,
so3_rotation_angle,
) )
@ -53,10 +54,10 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_determinant(self): def test_determinant(self):
""" """
Tests whether the determinants of 3x3 rotation matrices produced Tests whether the determinants of 3x3 rotation matrices produced
by `so3_exponential_map` are (almost) equal to 1. by `so3_exp_map` are (almost) equal to 1.
""" """
log_rot = TestSO3.init_log_rot(batch_size=30) log_rot = TestSO3.init_log_rot(batch_size=30)
Rs = so3_exponential_map(log_rot) Rs = so3_exp_map(log_rot)
dets = torch.det(Rs) dets = torch.det(Rs)
self.assertClose(dets, torch.ones_like(dets), atol=1e-4) self.assertClose(dets, torch.ones_like(dets), atol=1e-4)
@ -75,14 +76,14 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_bad_so3_input_value_err(self): def test_bad_so3_input_value_err(self):
""" """
Tests whether `so3_exponential_map` and `so3_log_map` correctly return Tests whether `so3_exp_map` and `so3_log_map` correctly return
a ValueError if called with an argument of incorrect shape or, in case a ValueError if called with an argument of incorrect shape or, in case
of `so3_exponential_map`, unexpected trace. of `so3_exp_map`, unexpected trace.
""" """
device = torch.device("cuda:0") device = torch.device("cuda:0")
log_rot = torch.randn(size=[5, 4], device=device) log_rot = torch.randn(size=[5, 4], device=device)
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
so3_exponential_map(log_rot) so3_exp_map(log_rot)
self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception)) self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception))
rot = torch.randn(size=[5, 3, 5], device=device) rot = torch.randn(size=[5, 3, 5], device=device)
@ -106,17 +107,22 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_so3_exp_singularity(self, batch_size: int = 100): def test_so3_exp_singularity(self, batch_size: int = 100):
""" """
Tests whether the `so3_exponential_map` is robust to the input vectors Tests whether the `so3_exp_map` is robust to the input vectors
the norms of which are close to the numerically unstable region the norms of which are close to the numerically unstable region
(vectors with low l2-norms). (vectors with low l2-norms).
""" """
# generate random log-rotations with a tiny angle # generate random log-rotations with a tiny angle
log_rot = TestSO3.init_log_rot(batch_size=batch_size) log_rot = TestSO3.init_log_rot(batch_size=batch_size)
log_rot_small = log_rot * 1e-6 log_rot_small = log_rot * 1e-6
R = so3_exponential_map(log_rot_small) log_rot_small.requires_grad = True
R = so3_exp_map(log_rot_small)
# tests whether all outputs are finite # tests whether all outputs are finite
R_sum = float(R.sum()) self.assertTrue(torch.isfinite(R).all())
self.assertEqual(R_sum, R_sum) # tests whether the gradient is not None and all finite
loss = R.sum()
loss.backward()
self.assertIsNotNone(log_rot_small.grad)
self.assertTrue(torch.isfinite(log_rot_small.grad).all())
def test_so3_log_singularity(self, batch_size: int = 100): def test_so3_log_singularity(self, batch_size: int = 100):
""" """
@ -129,6 +135,107 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
identity = torch.eye(3, device=device) identity = torch.eye(3, device=device)
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device) rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
r = [identity, rot180] r = [identity, rot180]
# add random rotations and random almost orthonormal matrices
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
# this adds random noise to the second half
# of the random orthogonal matrices to generate
# near-orthogonal matrices
for i in range(batch_size - 2)
]
)
r = torch.stack(r)
r.requires_grad = True
# the log of the rotation matrix r
r_log = so3_log_map(r, cos_bound=1e-4, eps=1e-2)
# tests whether all outputs are finite
self.assertTrue(torch.isfinite(r_log).all())
# tests whether the gradient is not None and all finite
loss = r.sum()
loss.backward()
self.assertIsNotNone(r.grad)
self.assertTrue(torch.isfinite(r.grad).all())
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that
`so3_exp_map(so3_log_map(so3_exp_map(log_rot)))
== so3_exp_map(log_rot)`
for a randomly generated batch of rotation matrix logarithms `log_rot`.
Unlike `test_so3_log_to_exp_to_log`, this test checks the
correctness of converting a `log_rot` which contains values > math.pi.
"""
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = {0, 2pi}
log_rot[:2] = 0
log_rot[1, 0] = 2.0 * math.pi - 1e-6
rot = so3_exp_map(log_rot, eps=1e-4)
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-4, cos_bound=1e-6), eps=1e-6)
self.assertClose(rot, rot_, atol=0.01)
angles = so3_relative_angle(rot, rot_, cos_bound=1e-6)
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
"""
Check that `so3_log_map(so3_exp_map(log_rot))==log_rot` for
a randomly generated batch of rotation matrix logarithms `log_rot`.
"""
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = 0
log_rot[:1] = 0
log_rot_ = so3_log_map(so3_exp_map(log_rot))
self.assertClose(log_rot, log_rot_, atol=1e-4)
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that `so3_exp_map(so3_log_map(R))==R` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
non_singular = (so3_rotation_angle(rot) - math.pi).abs() > 1e-2
rot = rot[non_singular]
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-8, cos_bound=1e-8), eps=1e-8)
self.assertClose(rot_, rot, atol=0.1)
angles = so3_relative_angle(rot, rot_, cos_bound=1e-4)
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
def test_so3_cos_relative_angle(self, batch_size: int = 100):
"""
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
is the same as `so3_relative_angle(R1, R2, cos_angle=True)` for
batches of randomly generated rotation matrices `R1` and `R2`.
"""
rot1 = TestSO3.init_rot(batch_size=batch_size)
rot2 = TestSO3.init_rot(batch_size=batch_size)
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
self.assertClose(angles, angles_, atol=1e-4)
def test_so3_cos_angle(self, batch_size: int = 100):
"""
Check that `so3_rotation_angle(R, cos_angle=False).cos()`
is the same as `so3_rotation_angle(R, cos_angle=True)` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
angles = so3_rotation_angle(rot, cos_angle=False).cos()
angles_ = so3_rotation_angle(rot, cos_angle=True)
self.assertClose(angles, angles_, atol=1e-4)
def test_so3_cos_bound(self, batch_size: int = 100):
"""
Checks that for an identity rotation `R=I`, the so3_rotation_angle returns
non-finite gradients when `cos_bound=None` and finite gradients
for `cos_bound > 0.0`.
"""
# generate random rotations with a tiny angle to generate cases
# with the gradient singularity
device = torch.device("cuda:0")
identity = torch.eye(3, device=device)
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
r = [identity, rot180]
r.extend( r.extend(
[ [
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0] torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
@ -136,65 +243,25 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
] ]
) )
r = torch.stack(r) r = torch.stack(r)
# the log of the rotation matrix r r.requires_grad = True
r_log = so3_log_map(r) for is_grad_finite in (True, False):
# tests whether all outputs are finite # clear the gradients and decide the cos_bound:
r_sum = float(r_log.sum()) # for is_grad_finite we run so3_rotation_angle with cos_bound
self.assertEqual(r_sum, r_sum) # set to a small float, otherwise we set to 0.0
r.grad = None
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100): cos_bound = 1e-4 if is_grad_finite else 0.0
""" # compute the angles of r
Check that angles = so3_rotation_angle(r, cos_bound=cos_bound)
`so3_exponential_map(so3_log_map(so3_exponential_map(log_rot))) # tests whether all outputs are finite in both cases
== so3_exponential_map(log_rot)` self.assertTrue(torch.isfinite(angles).all())
for a randomly generated batch of rotation matrix logarithms `log_rot`. # compute the gradients
Unlike `test_so3_log_to_exp_to_log`, this test checks the loss = angles.sum()
correctness of converting a `log_rot` which contains values > math.pi. loss.backward()
""" # tests whether the gradient is not None for both cases
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size) self.assertIsNotNone(r.grad)
# check also the singular cases where rot. angle = {0, pi, 2pi, 3pi} if is_grad_finite:
log_rot[:3] = 0 # all grad values have to be finite
log_rot[1, 0] = math.pi self.assertTrue(torch.isfinite(r.grad).all())
log_rot[2, 0] = 2.0 * math.pi
log_rot[3, 0] = 3.0 * math.pi
rot = so3_exponential_map(log_rot, eps=1e-8)
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
angles = so3_relative_angle(rot, rot_)
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
"""
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
a randomly generated batch of rotation matrix logarithms `log_rot`.
"""
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = 0
log_rot[:1] = 0
log_rot_ = so3_log_map(so3_exponential_map(log_rot))
self.assertClose(log_rot, log_rot_, atol=1e-4)
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that `so3_exponential_map(so3_log_map(R))==R` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
angles = so3_relative_angle(rot, rot_)
# TODO: a lot of precision lost here ...
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
def test_so3_cos_angle(self, batch_size: int = 100):
"""
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
is the same as `so3_relative_angle(R1, R2, cos_angle=True)`
batches of randomly generated rotation matrices `R1` and `R2`.
"""
rot1 = TestSO3.init_rot(batch_size=batch_size)
rot2 = TestSO3.init_rot(batch_size=batch_size)
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
self.assertClose(angles, angles_)
@staticmethod @staticmethod
def so3_expmap(batch_size: int = 10): def so3_expmap(batch_size: int = 10):
@ -202,7 +269,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize() torch.cuda.synchronize()
def compute_rots(): def compute_rots():
so3_exponential_map(log_rot) so3_exp_map(log_rot)
torch.cuda.synchronize() torch.cuda.synchronize()
return compute_rots return compute_rots