SO3 improvements for stable gradients.

Summary:
Improves so3 functions to make gradient computation stable:
- Instead of `torch.acos`, uses `acos_linear_extrapolation` which has finite gradients of reasonable magnitude for all inputs.
- Adds tests for the latter.

The gradient-finiteness tests in `test_so3_exp_singularity`, `test_so3_log_singularity`, and `test_so3_cos_bound` would fail if the `so3` functions called `torch.acos` instead of `acos_linear_extrapolation`.
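
For reference, a minimal sketch of the linear-extrapolation idea (the helper name `acos_safe` and the `bound` value are illustrative assumptions, not the exact `acos_linear_extrapolation` signature): inside a safe interval the exact `acos` is used, and outside it the function is continued linearly with the slope at the boundary, so the value and the gradient stay finite even for inputs at `±1`.

```python
import torch

def acos_safe(x: torch.Tensor, bound: float = 1.0 - 1e-4) -> torch.Tensor:
    # Inside [-bound, bound]: exact acos. Outside: continue acos linearly
    # with its slope at +/-bound, so both the value and the gradient stay
    # finite even for inputs at or slightly beyond +/-1.
    x_clamped = x.clamp(-bound, bound)
    acos_clamped = torch.acos(x_clamped)
    # slope of acos at the clamp point: d/dx acos(x) = -1 / sqrt(1 - x^2)
    slope = -1.0 / torch.sqrt(1.0 - x_clamped ** 2)
    # the linear term vanishes inside the bound, where x == x_clamped
    return acos_clamped + slope * (x - x_clamped)

x = torch.tensor([1.0], requires_grad=True)
acos_safe(x).sum().backward()
assert torch.isfinite(x.grad).all()  # torch.acos would yield an infinite gradient here
```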

Reviewed By: bottler

Differential Revision: D23326429

fbshipit-source-id: dc296abf2ae3ddfb3942c8146621491a9cb740ee
Author: David Novotny
Date: 2021-06-21 04:47:31 -07:00
Committed by: Facebook GitHub Bot
Parent: dd45123f20
Commit: 9f14e82b5a
3 changed files with 216 additions and 103 deletions


@@ -9,9 +9,10 @@ import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms.so3 import (
hat,
so3_exponential_map,
so3_exp_map,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,
)
@@ -53,10 +54,10 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_determinant(self):
"""
Tests whether the determinants of 3x3 rotation matrices produced
by `so3_exponential_map` are (almost) equal to 1.
by `so3_exp_map` are (almost) equal to 1.
"""
log_rot = TestSO3.init_log_rot(batch_size=30)
Rs = so3_exponential_map(log_rot)
Rs = so3_exp_map(log_rot)
dets = torch.det(Rs)
self.assertClose(dets, torch.ones_like(dets), atol=1e-4)
@@ -75,14 +76,14 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_bad_so3_input_value_err(self):
"""
Tests whether `so3_exponential_map` and `so3_log_map` correctly return
Tests whether `so3_exp_map` and `so3_log_map` correctly return
a ValueError if called with an argument of incorrect shape or, in case
of `so3_exponential_map`, unexpected trace.
of `so3_exp_map`, unexpected trace.
"""
device = torch.device("cuda:0")
log_rot = torch.randn(size=[5, 4], device=device)
with self.assertRaises(ValueError) as err:
so3_exponential_map(log_rot)
so3_exp_map(log_rot)
self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception))
rot = torch.randn(size=[5, 3, 5], device=device)
@@ -106,17 +107,22 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
def test_so3_exp_singularity(self, batch_size: int = 100):
"""
Tests whether the `so3_exponential_map` is robust to the input vectors
Tests whether the `so3_exp_map` is robust to the input vectors
the norms of which are close to the numerically unstable region
(vectors with low l2-norms).
"""
# generate random log-rotations with a tiny angle
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
log_rot_small = log_rot * 1e-6
R = so3_exponential_map(log_rot_small)
log_rot_small.requires_grad = True
R = so3_exp_map(log_rot_small)
# tests whether all outputs are finite
R_sum = float(R.sum())
self.assertEqual(R_sum, R_sum)
self.assertTrue(torch.isfinite(R).all())
# tests whether the gradient is not None and all finite
loss = R.sum()
loss.backward()
self.assertIsNotNone(log_rot_small.grad)
self.assertTrue(torch.isfinite(log_rot_small.grad).all())
def test_so3_log_singularity(self, batch_size: int = 100):
"""
@@ -129,6 +135,107 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
identity = torch.eye(3, device=device)
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
r = [identity, rot180]
# add random rotations and random almost orthonormal matrices
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
# this adds random noise to the second half
# of the random orthogonal matrices to generate
# near-orthogonal matrices
for i in range(batch_size - 2)
]
)
r = torch.stack(r)
r.requires_grad = True
# the log of the rotation matrix r
r_log = so3_log_map(r, cos_bound=1e-4, eps=1e-2)
# tests whether all outputs are finite
self.assertTrue(torch.isfinite(r_log).all())
# tests whether the gradient is not None and all finite
loss = r_log.sum()
loss.backward()
self.assertIsNotNone(r.grad)
self.assertTrue(torch.isfinite(r.grad).all())
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that
`so3_exp_map(so3_log_map(so3_exp_map(log_rot)))
== so3_exp_map(log_rot)`
for a randomly generated batch of rotation matrix logarithms `log_rot`.
Unlike `test_so3_log_to_exp_to_log`, this test checks the
correctness of converting a `log_rot` which contains values > math.pi.
"""
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = {0, 2pi}
log_rot[:2] = 0
log_rot[1, 0] = 2.0 * math.pi - 1e-6
rot = so3_exp_map(log_rot, eps=1e-4)
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-4, cos_bound=1e-6), eps=1e-6)
self.assertClose(rot, rot_, atol=0.01)
angles = so3_relative_angle(rot, rot_, cos_bound=1e-6)
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
"""
Check that `so3_log_map(so3_exp_map(log_rot))==log_rot` for
a randomly generated batch of rotation matrix logarithms `log_rot`.
"""
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = 0
log_rot[:1] = 0
log_rot_ = so3_log_map(so3_exp_map(log_rot))
self.assertClose(log_rot, log_rot_, atol=1e-4)
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that `so3_exp_map(so3_log_map(R))==R` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
non_singular = (so3_rotation_angle(rot) - math.pi).abs() > 1e-2
rot = rot[non_singular]
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-8, cos_bound=1e-8), eps=1e-8)
self.assertClose(rot_, rot, atol=0.1)
angles = so3_relative_angle(rot, rot_, cos_bound=1e-4)
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
def test_so3_cos_relative_angle(self, batch_size: int = 100):
"""
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
is the same as `so3_relative_angle(R1, R2, cos_angle=True)` for
batches of randomly generated rotation matrices `R1` and `R2`.
"""
rot1 = TestSO3.init_rot(batch_size=batch_size)
rot2 = TestSO3.init_rot(batch_size=batch_size)
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
self.assertClose(angles, angles_, atol=1e-4)
def test_so3_cos_angle(self, batch_size: int = 100):
"""
Check that `so3_rotation_angle(R, cos_angle=False).cos()`
is the same as `so3_rotation_angle(R, cos_angle=True)` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
angles = so3_rotation_angle(rot, cos_angle=False).cos()
angles_ = so3_rotation_angle(rot, cos_angle=True)
self.assertClose(angles, angles_, atol=1e-4)
def test_so3_cos_bound(self, batch_size: int = 100):
"""
Checks that for an identity rotation `R=I`, the so3_rotation_angle returns
non-finite gradients when `cos_bound=None` and finite gradients
for `cos_bound > 0.0`.
"""
# generate random rotations with a tiny angle to generate cases
# with the gradient singularity
device = torch.device("cuda:0")
identity = torch.eye(3, device=device)
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
r = [identity, rot180]
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
@@ -136,65 +243,25 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
]
)
r = torch.stack(r)
# the log of the rotation matrix r
r_log = so3_log_map(r)
# tests whether all outputs are finite
r_sum = float(r_log.sum())
self.assertEqual(r_sum, r_sum)
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that
`so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
== so3_exponential_map(log_rot)`
for a randomly generated batch of rotation matrix logarithms `log_rot`.
Unlike `test_so3_log_to_exp_to_log`, this test checks the
correctness of converting a `log_rot` which contains values > math.pi.
"""
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = {0, pi, 2pi, 3pi}
log_rot[:3] = 0
log_rot[1, 0] = math.pi
log_rot[2, 0] = 2.0 * math.pi
log_rot[3, 0] = 3.0 * math.pi
rot = so3_exponential_map(log_rot, eps=1e-8)
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
angles = so3_relative_angle(rot, rot_)
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
"""
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
a randomly generated batch of rotation matrix logarithms `log_rot`.
"""
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
# check also the singular cases where rot. angle = 0
log_rot[:1] = 0
log_rot_ = so3_log_map(so3_exponential_map(log_rot))
self.assertClose(log_rot, log_rot_, atol=1e-4)
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
"""
Check that `so3_exponential_map(so3_log_map(R))==R` for
a batch of randomly generated rotation matrices `R`.
"""
rot = TestSO3.init_rot(batch_size=batch_size)
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
angles = so3_relative_angle(rot, rot_)
# TODO: a lot of precision lost here ...
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
def test_so3_cos_angle(self, batch_size: int = 100):
"""
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
is the same as `so3_relative_angle(R1, R2, cos_angle=True)`
batches of randomly generated rotation matrices `R1` and `R2`.
"""
rot1 = TestSO3.init_rot(batch_size=batch_size)
rot2 = TestSO3.init_rot(batch_size=batch_size)
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
self.assertClose(angles, angles_)
r.requires_grad = True
for is_grad_finite in (True, False):
# clear the gradients and decide the cos_bound:
# for is_grad_finite we run so3_rotation_angle with cos_bound
# set to a small float, otherwise we set to 0.0
r.grad = None
cos_bound = 1e-4 if is_grad_finite else 0.0
# compute the angles of r
angles = so3_rotation_angle(r, cos_bound=cos_bound)
# tests whether all outputs are finite in both cases
self.assertTrue(torch.isfinite(angles).all())
# compute the gradients
loss = angles.sum()
loss.backward()
# tests whether the gradient is not None for both cases
self.assertIsNotNone(r.grad)
if is_grad_finite:
# all grad values have to be finite
self.assertTrue(torch.isfinite(r.grad).all())
@staticmethod
def so3_expmap(batch_size: int = 10):
@@ -202,7 +269,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize()
def compute_rots():
so3_exponential_map(log_rot)
so3_exp_map(log_rot)
torch.cuda.synchronize()
return compute_rots