# pytorch3d/tests/test_so3.py
# Jeremy Reizenstein 34f648ede0 move targets
# Summary: Move testing targets from pytorch3d/tests/TARGETS to pytorch3d/TARGETS.
#
# Reviewed By: shapovalov
#
# Differential Revision: D36186940
#
# fbshipit-source-id: a4c52c4d99351f885e2b0bf870532d530324039b
# 2022-05-25 06:16:03 -07:00
#
# 299 lines
# 12 KiB
# Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import unittest
from distutils.version import LooseVersion
import numpy as np
import torch
from pytorch3d.common.compat import qr
from pytorch3d.transforms.so3 import (
hat,
so3_exp_map,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,
)
from .common_testing import TestCaseMixin
class TestSO3(TestCaseMixin, unittest.TestCase):
    """
    Tests for the SO(3) helpers in `pytorch3d.transforms.so3`
    (`hat`, `so3_exp_map`, `so3_log_map`, `so3_relative_angle` and
    `so3_rotation_angle`), plus two static benchmark factories.

    NOTE: every tensor is created on `cuda:0`, so a CUDA device is required.
    """

    def setUp(self) -> None:
        super().setUp()
        # Fix both RNGs so that the random test inputs are reproducible.
        torch.manual_seed(42)
        np.random.seed(42)

    @staticmethod
    def init_log_rot(batch_size: int = 10):
        """
        Initialize a list of `batch_size` 3-dimensional vectors representing
        randomly generated logarithms of rotation matrices.

        Args:
            batch_size: Number of vectors to generate.

        Returns:
            A float32 tensor of shape `(batch_size, 3)` on `cuda:0` with
            entries drawn from a standard normal distribution.
        """
        device = torch.device("cuda:0")
        log_rot = torch.randn((batch_size, 3), dtype=torch.float32, device=device)
        return log_rot

    @staticmethod
    def init_rot(batch_size: int = 10):
        """
        Randomly generate a batch of `batch_size` 3x3 rotation matrices.

        Args:
            batch_size: Number of matrices to generate.

        Returns:
            A tensor of shape `(batch_size, 3, 3)` on `cuda:0` whose slices
            are orthogonal with determinant +1.
        """
        device = torch.device("cuda:0")
        # TODO(dnovotny): replace with random_rotation from random_rotation.py
        rot = []
        for _ in range(batch_size):
            r = qr(torch.randn((3, 3), device=device))[0]
            # The orthogonal QR factor may be a reflection (det = -1) depending
            # on the backend. Negating a 3x3 matrix flips the sign of its
            # determinant, so this guarantees det(r) = +1.
            r = r * torch.det(r).sign()
            f = torch.randint(2, (3,), device=device, dtype=torch.float32)
            if f.sum() % 2 == 0:
                f = 1 - f
            # `f` now has an odd number of ones, hence `2 * f - 1` has an even
            # number of -1 entries: flipping those columns keeps det(r) = +1.
            rot.append(r * (2 * f - 1))
        rot = torch.stack(rot)
        return rot

    def test_determinant(self):
        """
        Tests whether the determinants of 3x3 rotation matrices produced
        by `so3_exp_map` are (almost) equal to 1.
        """
        log_rot = TestSO3.init_log_rot(batch_size=30)
        Rs = so3_exp_map(log_rot)
        dets = torch.det(Rs)
        self.assertClose(dets, torch.ones_like(dets), atol=1e-4)

    def test_cross(self):
        """
        For a pair of randomly generated 3-dimensional vectors `a` and `b`,
        tests whether a matrix product of `hat(a)` and `b` equals the result
        of a cross product between `a` and `b`.
        """
        device = torch.device("cuda:0")
        a, b = torch.randn((2, 100, 3), dtype=torch.float32, device=device)
        hat_a = hat(a)
        cross = torch.bmm(hat_a, b[:, :, None])[:, :, 0]
        torch_cross = torch.cross(a, b, dim=1)
        self.assertClose(torch_cross, cross, atol=1e-4)

    def test_bad_so3_input_value_err(self):
        """
        Tests whether `so3_exp_map` and `so3_log_map` correctly raise
        a ValueError if called with an argument of incorrect shape or, in case
        of `so3_log_map`, with matrices whose trace is outside the valid range.
        """
        device = torch.device("cuda:0")
        log_rot = torch.randn(size=[5, 4], device=device)
        with self.assertRaises(ValueError) as err:
            so3_exp_map(log_rot)
        self.assertIn("Input tensor shape has to be Nx3.", str(err.exception))

        rot = torch.randn(size=[5, 3, 5], device=device)
        with self.assertRaises(ValueError) as err:
            so3_log_map(rot)
        self.assertIn("Input has to be a batch of 3x3 Tensors.", str(err.exception))

        # trace of rot definitely bigger than 3 or smaller than -1
        rot = torch.cat(
            (
                torch.rand(size=[5, 3, 3], device=device) + 4.0,
                torch.rand(size=[5, 3, 3], device=device) - 3.0,
            )
        )
        with self.assertRaises(ValueError) as err:
            so3_log_map(rot)
        self.assertIn(
            "A matrix has trace outside valid range [-1-eps,3+eps].",
            str(err.exception),
        )

    def test_so3_exp_singularity(self, batch_size: int = 100):
        """
        Tests whether the `so3_exp_map` is robust to the input vectors
        the norms of which are close to the numerically unstable region
        (vectors with low l2-norms).
        """
        # generate random log-rotations with a tiny angle
        log_rot = TestSO3.init_log_rot(batch_size=batch_size)
        log_rot_small = log_rot * 1e-6
        log_rot_small.requires_grad = True
        R = so3_exp_map(log_rot_small)
        # tests whether all outputs are finite
        self.assertTrue(torch.isfinite(R).all())
        # tests whether the gradient is not None and all finite
        loss = R.sum()
        loss.backward()
        self.assertIsNotNone(log_rot_small.grad)
        self.assertTrue(torch.isfinite(log_rot_small.grad).all())

    def test_so3_log_singularity(self, batch_size: int = 100):
        """
        Tests whether the `so3_log_map` is robust to the input matrices
        whose rotation angles are close to the numerically unstable region
        (i.e. matrices with low rotation angles).
        """
        # generate random rotations with a tiny angle
        device = torch.device("cuda:0")
        identity = torch.eye(3, device=device)
        rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
        r = [identity, rot180]
        # add random rotations and random almost orthonormal matrices
        r.extend(
            [
                qr(identity + torch.randn_like(identity) * 1e-4)[0]
                + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
                # this adds random noise to the second half
                # of the random orthogonal matrices to generate
                # near-orthogonal matrices
                for i in range(batch_size - 2)
            ]
        )
        r = torch.stack(r)
        r.requires_grad = True
        # the log of the rotation matrix r
        r_log = so3_log_map(r, cos_bound=1e-4, eps=1e-2)
        # tests whether all outputs are finite
        self.assertTrue(torch.isfinite(r_log).all())
        # tests whether the gradient is not None and all finite; the loss has
        # to depend on `r_log`, otherwise `so3_log_map` is not differentiated
        loss = r_log.sum()
        loss.backward()
        self.assertIsNotNone(r.grad)
        self.assertTrue(torch.isfinite(r.grad).all())

    def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
        """
        Check that
        `so3_exp_map(so3_log_map(so3_exp_map(log_rot)))
        == so3_exp_map(log_rot)`
        for a randomly generated batch of rotation matrix logarithms `log_rot`.
        Unlike `test_so3_log_to_exp_to_log`, this test checks the
        correctness of converting a `log_rot` which contains values > math.pi.
        """
        log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
        # check also the singular cases where rot. angle = {0, 2pi}
        log_rot[:2] = 0
        log_rot[1, 0] = 2.0 * math.pi - 1e-6
        rot = so3_exp_map(log_rot, eps=1e-4)
        rot_ = so3_exp_map(so3_log_map(rot, eps=1e-4, cos_bound=1e-6), eps=1e-6)
        self.assertClose(rot, rot_, atol=0.01)
        angles = so3_relative_angle(rot, rot_, cos_bound=1e-6)
        self.assertClose(angles, torch.zeros_like(angles), atol=0.01)

    def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
        """
        Check that `so3_log_map(so3_exp_map(log_rot))==log_rot` for
        a randomly generated batch of rotation matrix logarithms `log_rot`.
        """
        log_rot = TestSO3.init_log_rot(batch_size=batch_size)
        # check also the singular cases where rot. angle = 0
        log_rot[:1] = 0
        log_rot_ = so3_log_map(so3_exp_map(log_rot))
        self.assertClose(log_rot, log_rot_, atol=1e-4)

    def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
        """
        Check that `so3_exp_map(so3_log_map(R))==R` for
        a batch of randomly generated rotation matrices `R`.
        """
        rot = TestSO3.init_rot(batch_size=batch_size)
        # drop rotations whose angle is within 1e-2 of pi
        non_singular = (so3_rotation_angle(rot) - math.pi).abs() > 1e-2
        rot = rot[non_singular]
        rot_ = so3_exp_map(so3_log_map(rot, eps=1e-8, cos_bound=1e-8), eps=1e-8)
        self.assertClose(rot_, rot, atol=0.1)
        angles = so3_relative_angle(rot, rot_, cos_bound=1e-4)
        self.assertClose(angles, torch.zeros_like(angles), atol=0.1)

    def test_so3_cos_relative_angle(self, batch_size: int = 100):
        """
        Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
        is the same as `so3_relative_angle(R1, R2, cos_angle=True)` for
        batches of randomly generated rotation matrices `R1` and `R2`.
        """
        rot1 = TestSO3.init_rot(batch_size=batch_size)
        rot2 = TestSO3.init_rot(batch_size=batch_size)
        angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
        angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
        self.assertClose(angles, angles_, atol=1e-4)

    def test_so3_cos_angle(self, batch_size: int = 100):
        """
        Check that `so3_rotation_angle(R, cos_angle=False).cos()`
        is the same as `so3_rotation_angle(R, cos_angle=True)` for
        a batch of randomly generated rotation matrices `R`.
        """
        rot = TestSO3.init_rot(batch_size=batch_size)
        angles = so3_rotation_angle(rot, cos_angle=False).cos()
        angles_ = so3_rotation_angle(rot, cos_angle=True)
        self.assertClose(angles, angles_, atol=1e-4)

    def test_so3_cos_bound(self, batch_size: int = 100):
        """
        Checks that for a batch containing the identity rotation `R=I` (and a
        180-degree rotation), `so3_rotation_angle` returns non-finite gradients
        when `cos_bound=0.0` and finite gradients for `cos_bound > 0.0`.
        """
        # generate random rotations with a tiny angle to generate cases
        # with the gradient singularity
        device = torch.device("cuda:0")
        identity = torch.eye(3, device=device)
        rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
        r = [identity, rot180]
        r.extend(
            [
                qr(identity + torch.randn_like(identity) * 1e-4)[0]
                for _ in range(batch_size - 2)
            ]
        )
        r = torch.stack(r)
        r.requires_grad = True
        for is_grad_finite in (True, False):
            # clear the gradients and decide the cos_bound:
            # for is_grad_finite we run so3_rotation_angle with cos_bound
            # set to a small float, otherwise we set to 0.0
            r.grad = None
            cos_bound = 1e-4 if is_grad_finite else 0.0
            # compute the angles of r
            angles = so3_rotation_angle(r, cos_bound=cos_bound)
            # tests whether all outputs are finite in both cases
            self.assertTrue(torch.isfinite(angles).all())
            # compute the gradients
            loss = angles.sum()
            loss.backward()
            # tests whether the gradient is not None for both cases
            self.assertIsNotNone(r.grad)
            if is_grad_finite:
                # all grad values have to be finite
                self.assertTrue(torch.isfinite(r.grad).all())
            else:
                # the identity / 180-degree entries hit the unbounded
                # derivative of acos at +-1, so some grads must be non-finite
                self.assertFalse(torch.isfinite(r.grad).all())

    @unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
    def test_scriptable(self):
        """Check that `so3_exp_map` and `so3_log_map` compile with TorchScript."""
        torch.jit.script(so3_exp_map)
        torch.jit.script(so3_log_map)

    @staticmethod
    def so3_expmap(batch_size: int = 10):
        """
        Benchmark factory: returns a closure that runs `so3_exp_map` on a
        fixed random batch of `batch_size` log-rotations and waits for CUDA.
        """
        log_rot = TestSO3.init_log_rot(batch_size=batch_size)
        torch.cuda.synchronize()

        def compute_rots():
            so3_exp_map(log_rot)
            torch.cuda.synchronize()

        return compute_rots

    @staticmethod
    def so3_logmap(batch_size: int = 10):
        """
        Benchmark factory: returns a closure that runs `so3_log_map` on a
        fixed random batch of `batch_size` rotation matrices and waits for CUDA.
        """
        rot = TestSO3.init_rot(batch_size=batch_size)
        torch.cuda.synchronize()

        def compute_logs():
            so3_log_map(rot)
            torch.cuda.synchronize()

        return compute_logs