SO3 improvements for stable gradients.

Summary:
Improves so3 functions to make gradient computation stable:
- Instead of `torch.acos`, uses `acos_linear_extrapolation` which has finite gradients of reasonable magnitude for all inputs.
- Adds tests for the latter.

The tests of the finiteness of the gradient in `test_so3_exp_singularity`, `test_so3_exp_singularity`, `test_so3_cos_bound` would fail if the `so3` functions would call `torch.acos` instead of `acos_linear_extrapolation`.

Reviewed By: bottler

Differential Revision: D23326429

fbshipit-source-id: dc296abf2ae3ddfb3942c8146621491a9cb740ee
This commit is contained in:
David Novotny 2021-06-21 04:47:31 -07:00 committed by Facebook GitHub Bot
parent dd45123f20
commit 9f14e82b5a
3 changed files with 216 additions and 103 deletions

View File

@ -22,6 +22,7 @@ from .rotation_conversions import (
)
from .so3 import (
so3_exponential_map,
so3_exp_map,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,

View File

@ -1,13 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple
import torch
from ..transforms import acos_linear_extrapolation
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
def so3_relative_angle(R1, R2, cos_angle: bool = False):
def so3_relative_angle(
R1: torch.Tensor,
R2: torch.Tensor,
cos_angle: bool = False,
cos_bound: float = 1e-4,
) -> torch.Tensor:
"""
Calculates the relative angle (in radians) between pairs of
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
@ -20,8 +27,12 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
cos_angle: If==True return cosine of the relative angle rather than
the angle itself. This can avoid the unstable
calculation of `acos`.
the angle itself. This can avoid the unstable calculation of `acos`.
cos_bound: Clamps the cosine of the relative rotation angle to
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call. Note that the non-finite outputs/gradients
are returned when the angle is requested (i.e. `cos_angle==False`)
and the rotation angle is close to 0 or π.
Returns:
Corresponding rotation angles of shape `(minibatch,)`.
@ -32,10 +43,15 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
ValueError if `R1` or `R2` has an unexpected trace.
"""
R12 = torch.bmm(R1, R2.permute(0, 2, 1))
return so3_rotation_angle(R12, cos_angle=cos_angle)
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
def so3_rotation_angle(
R: torch.Tensor,
eps: float = 1e-4,
cos_angle: bool = False,
cos_bound: float = 1e-4,
) -> torch.Tensor:
"""
Calculates angles (in radians) of a batch of rotation matrices `R` with
`angle = acos(0.5 * (Trace(R)-1))`. The trace of the
@ -47,8 +63,13 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: Tolerance for the valid trace check.
cos_angle: If==True return cosine of the rotation angles rather than
the angle itself. This can avoid the unstable
calculation of `acos`.
the angle itself. This can avoid the unstable
calculation of `acos`.
cos_bound: Clamps the cosine of the rotation angle to
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call. Note that the non-finite outputs/gradients
are returned when the angle is requested (i.e. `cos_angle==False`)
and the rotation angle is close to 0 or π.
Returns:
Corresponding rotation angles of shape `(minibatch,)`.
@ -68,20 +89,19 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
# clamp to valid range
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)
# phi ... rotation angle
phi = 0.5 * (rot_trace - 1.0)
phi_cos = (rot_trace - 1.0) * 0.5
if cos_angle:
return phi
return phi_cos
else:
# pyre-fixme[16]: `float` has no attribute `acos`.
return phi.acos()
if cos_bound > 0.0:
return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
else:
return torch.acos(phi_cos)
def so3_exponential_map(log_rot, eps: float = 0.0001):
def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
"""
Convert a batch of logarithmic representations of rotation matrices `log_rot`
to a batch of 3x3 rotation matrices using Rodrigues formula [1].
@ -94,18 +114,31 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
which is handled by clamping controlled with the `eps` argument.
Args:
log_rot: Batch of vectors of shape `(minibatch , 3)`.
log_rot: Batch of vectors of shape `(minibatch, 3)`.
eps: A float constant handling the conversion singularity.
Returns:
Batch of rotation matrices of shape `(minibatch , 3 , 3)`.
Batch of rotation matrices of shape `(minibatch, 3, 3)`.
Raises:
ValueError if `log_rot` is of incorrect shape.
[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
"""
return _so3_exp_map(log_rot, eps=eps)[0]
so3_exponential_map = so3_exp_map
def _so3_exp_map(
log_rot: torch.Tensor, eps: float = 0.0001
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
A helper function that computes the so3 exponential map and,
apart from the rotation matrix, also returns intermediate variables
that can be re-used in other functions.
"""
_, dim = log_rot.shape
if dim != 3:
raise ValueError("Input tensor shape has to be Nx3.")
@ -117,27 +150,35 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
fac1 = rot_angles_inv * rot_angles.sin()
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
skews = hat(log_rot)
skews_square = torch.bmm(skews, skews)
R = (
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
fac1[:, None, None] * skews
+ fac2[:, None, None] * torch.bmm(skews, skews)
+ fac2[:, None, None] * skews_square
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
)
return R
return R, rot_angles, skews, skews_square
def so3_log_map(R, eps: float = 0.0001):
def so3_log_map(
R: torch.Tensor, eps: float = 0.0001, cos_bound: float = 1e-4
) -> torch.Tensor:
"""
Convert a batch of 3x3 rotation matrices `R`
to a batch of 3-dimensional matrix logarithms of rotation matrices
The conversion has a singularity around `(R=I)` which is handled
by clamping controlled with the `eps` argument.
by clamping controlled with the `eps` and `cos_bound` arguments.
Args:
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: A float constant handling the conversion singularity.
cos_bound: Clamps the cosine of the rotation angle to
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call when computing `so3_rotation_angle`.
Note that the non-finite outputs/gradients are returned when
the rotation angle is close to 0 or π.
Returns:
Batch of logarithms of input rotation matrices
@ -152,22 +193,26 @@ def so3_log_map(R, eps: float = 0.0001):
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
phi = so3_rotation_angle(R)
phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)
phi_sin = phi.sin()
phi_sin = torch.sin(phi)
phi_denom = (
torch.clamp(phi_sin.abs(), eps) * phi_sin.sign()
+ (phi_sin == 0).type_as(phi) * eps
)
# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
phi_factor = torch.empty_like(phi)
ok_denom = phi_sin.abs() > (0.5 * eps)
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1))
log_rot = hat_inv(log_rot_hat)
return log_rot
def hat_inv(h):
def hat_inv(h: torch.Tensor) -> torch.Tensor:
"""
Compute the inverse Hat operator [1] of a batch of 3x3 matrices.
@ -188,9 +233,9 @@ def hat_inv(h):
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
ss_diff = (h + h.permute(0, 2, 1)).abs().max()
ss_diff = torch.abs(h + h.permute(0, 2, 1)).max()
if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
raise ValueError("One of input matrices not skew-symmetric.")
raise ValueError("One of input matrices is not skew-symmetric.")
x = h[:, 2, 1]
y = h[:, 0, 2]
@ -201,7 +246,7 @@ def hat_inv(h):
return v
def hat(v):
def hat(v: torch.Tensor) -> torch.Tensor:
"""
Compute the Hat operator [1] of a batch of 3D vectors.
@ -225,7 +270,7 @@ def hat(v):
if dim != 3:
raise ValueError("Input vectors have to be 3-dimensional.")
h = v.new_zeros(N, 3, 3)
h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device)
x, y, z = v.unbind(1)

View File

@ -9,9 +9,10 @@ import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms.so3 import (
hat,
so3_exponential_map,
so3_exp_map,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,
)
@ -53,10 +54,10 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_determinant(self):
"""
Tests whether the determinants of 3x3 rotation matrices produced
by `so3_exponential_map` are (almost) equal to 1.
by `so3_exp_map` are (almost) equal to 1.
"""
log_rot = TestSO3.init_log_rot(batch_size=30)
Rs = so3_exponential_map(log_rot)
Rs = so3_exp_map(log_rot)
dets = torch.det(Rs)
self.assertClose(dets, torch.ones_like(dets), atol=1e-4)
@ -75,14 +76,14 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_bad_so3_input_value_err(self):
"""
Tests whether `so3_exponential_map` and `so3_log_map` correctly return
Tests whether `so3_exp_map` and `so3_log_map` correctly return
a ValueError if called with an argument of incorrect shape or, in case
of `so3_exponential_map`, unexpected trace.
of `so3_exp_map`, unexpected trace.
"""
device = torch.device("cuda:0")
log_rot = torch.randn(size=[5, 4], device=device)
with self.assertRaises(ValueError) as err:
so3_exponential_map(log_rot)
so3_exp_map(log_rot)
self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception))
rot = torch.randn(size=[5, 3, 5], device=device)
@ -106,17 +107,22 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_so3_exp_singularity(self, batch_size: int = 100):
"""
Tests whether the `so3_exponential_map` is robust to the input vectors
Tests whether the `so3_exp_map` is robust to the input vectors
the norms of which are close to the numerically unstable region
(vectors with low l2-norms).
"""
# generate random log-rotations with a tiny angle
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
log_rot_small = log_rot * 1e-6
R = so3_exponential_map(log_rot_small)
log_rot_small.requires_grad = True
R = so3_exp_map(log_rot_small)
# tests whether all outputs are finite
R_sum = float(R.sum())
self.assertEqual(R_sum, R_sum)
self.assertTrue(torch.isfinite(R).all())
# tests whether the gradient is not None and all finite
loss = R.sum()
loss.backward()
self.assertIsNotNone(log_rot_small.grad)
self.assertTrue(torch.isfinite(log_rot_small.grad).all())
def test_so3_log_singularity(self, batch_size: int = 100):
"""
@ -129,6 +135,107 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
identity = torch.eye(3, device=device)
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
r = [identity, rot180]
# add random rotations and random almost orthonormal matrices
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
# this adds random noise to the second half
# of the random orthogonal matrices to generate
# near-orthogonal matrices
for i in range(batch_size - 2)
]
)
r = torch.stack(r)
r.requires_grad = True
# the log of the rotation matrix r
r_log = so3_log_map(r, cos_bound=1e-4, eps=1e-2)
# tests whether all outputs are finite
self.assertTrue(torch.isfinite(r_log).all())
# tests whether the gradient is not None and all finite
loss = r.sum()
loss.backward()
self.assertIsNotNone(r.grad)
self.assertTrue(torch.isfinite(r.grad).all())
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that
`so3_exp_map(so3_log_map(so3_exp_map(log_rot)))
== so3_exp_map(log_rot)`
for a randomly generated batch of rotation matrix logarithms `log_rot`.
Unlike `test_so3_log_to_exp_to_log`, this test checks the
correctness of converting a `log_rot` which contains values > math.pi.
"""
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = {0, 2pi}
log_rot[:2] = 0
log_rot[1, 0] = 2.0 * math.pi - 1e-6
rot = so3_exp_map(log_rot, eps=1e-4)
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-4, cos_bound=1e-6), eps=1e-6)
self.assertClose(rot, rot_, atol=0.01)
angles = so3_relative_angle(rot, rot_, cos_bound=1e-6)
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
"""
Check that `so3_log_map(so3_exp_map(log_rot))==log_rot` for
a randomly generated batch of rotation matrix logarithms `log_rot`.
"""
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = 0
log_rot[:1] = 0
log_rot_ = so3_log_map(so3_exp_map(log_rot))
self.assertClose(log_rot, log_rot_, atol=1e-4)
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that `so3_exp_map(so3_log_map(R))==R` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
non_singular = (so3_rotation_angle(rot) - math.pi).abs() > 1e-2
rot = rot[non_singular]
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-8, cos_bound=1e-8), eps=1e-8)
self.assertClose(rot_, rot, atol=0.1)
angles = so3_relative_angle(rot, rot_, cos_bound=1e-4)
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
def test_so3_cos_relative_angle(self, batch_size: int = 100):
"""
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
is the same as `so3_relative_angle(R1, R2, cos_angle=True)` for
batches of randomly generated rotation matrices `R1` and `R2`.
"""
rot1 = TestSO3.init_rot(batch_size=batch_size)
rot2 = TestSO3.init_rot(batch_size=batch_size)
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
self.assertClose(angles, angles_, atol=1e-4)
def test_so3_cos_angle(self, batch_size: int = 100):
"""
Check that `so3_rotation_angle(R, cos_angle=False).cos()`
is the same as `so3_rotation_angle(R, cos_angle=True)` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
angles = so3_rotation_angle(rot, cos_angle=False).cos()
angles_ = so3_rotation_angle(rot, cos_angle=True)
self.assertClose(angles, angles_, atol=1e-4)
def test_so3_cos_bound(self, batch_size: int = 100):
"""
Checks that for an identity rotation `R=I`, the so3_rotation_angle returns
non-finite gradients when `cos_bound=None` and finite gradients
for `cos_bound > 0.0`.
"""
# generate random rotations with a tiny angle to generate cases
# with the gradient singularity
device = torch.device("cuda:0")
identity = torch.eye(3, device=device)
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
r = [identity, rot180]
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
@ -136,65 +243,25 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
]
)
r = torch.stack(r)
# the log of the rotation matrix r
r_log = so3_log_map(r)
# tests whether all outputs are finite
r_sum = float(r_log.sum())
self.assertEqual(r_sum, r_sum)
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that
`so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
== so3_exponential_map(log_rot)`
for a randomly generated batch of rotation matrix logarithms `log_rot`.
Unlike `test_so3_log_to_exp_to_log`, this test checks the
correctness of converting a `log_rot` which contains values > math.pi.
"""
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = {0, pi, 2pi, 3pi}
log_rot[:3] = 0
log_rot[1, 0] = math.pi
log_rot[2, 0] = 2.0 * math.pi
log_rot[3, 0] = 3.0 * math.pi
rot = so3_exponential_map(log_rot, eps=1e-8)
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
angles = so3_relative_angle(rot, rot_)
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
"""
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
a randomly generated batch of rotation matrix logarithms `log_rot`.
"""
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = 0
log_rot[:1] = 0
log_rot_ = so3_log_map(so3_exponential_map(log_rot))
self.assertClose(log_rot, log_rot_, atol=1e-4)
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that `so3_exponential_map(so3_log_map(R))==R` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
angles = so3_relative_angle(rot, rot_)
# TODO: a lot of precision lost here ...
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
def test_so3_cos_angle(self, batch_size: int = 100):
"""
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
is the same as `so3_relative_angle(R1, R2, cos_angle=True)`
batches of randomly generated rotation matrices `R1` and `R2`.
"""
rot1 = TestSO3.init_rot(batch_size=batch_size)
rot2 = TestSO3.init_rot(batch_size=batch_size)
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
self.assertClose(angles, angles_)
r.requires_grad = True
for is_grad_finite in (True, False):
# clear the gradients and decide the cos_bound:
# for is_grad_finite we run so3_rotation_angle with cos_bound
# set to a small float, otherwise we set to 0.0
r.grad = None
cos_bound = 1e-4 if is_grad_finite else 0.0
# compute the angles of r
angles = so3_rotation_angle(r, cos_bound=cos_bound)
# tests whether all outputs are finite in both cases
self.assertTrue(torch.isfinite(angles).all())
# compute the gradients
loss = angles.sum()
loss.backward()
# tests whether the gradient is not None for both cases
self.assertIsNotNone(r.grad)
if is_grad_finite:
# all grad values have to be finite
self.assertTrue(torch.isfinite(r.grad).all())
@staticmethod
def so3_expmap(batch_size: int = 10):
@ -202,7 +269,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize()
def compute_rots():
so3_exponential_map(log_rot)
so3_exp_map(log_rot)
torch.cuda.synchronize()
return compute_rots