From 34a0df0630c964d4e4be225b1dc0ccf166743e75 Mon Sep 17 00:00:00 2001 From: David Novotny Date: Sun, 10 May 2020 13:14:10 -0700 Subject: [PATCH] SO3 log map fix for singularity at PI Summary: Fixes the case where the rotation angle is exactly 0/PI. Added a test for `so3_log_map(identity_matrix)`. Reviewed By: nikhilaravi Differential Revision: D21477078 fbshipit-source-id: adff804da97f6f0d4f50aa1f6904a34832cb8bfe --- pytorch3d/transforms/so3.py | 9 ++++-- tests/test_so3.py | 60 ++++++++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/pytorch3d/transforms/so3.py b/pytorch3d/transforms/so3.py index 18c63b78..57ef51b0 100644 --- a/pytorch3d/transforms/so3.py +++ b/pytorch3d/transforms/so3.py @@ -152,11 +152,14 @@ def so3_log_map(R, eps: float = 0.0001): phi = so3_rotation_angle(R) - phi_valid = torch.clamp(phi.abs(), eps) * phi.sign() + phi_sin = phi.sin() - log_rot_hat = (phi_valid / (2.0 * phi_valid.sin()))[:, None, None] * ( - R - R.permute(0, 2, 1) + phi_denom = ( + torch.clamp(phi_sin.abs(), eps) * phi_sin.sign() + + (phi_sin == 0).type_as(phi) * eps ) + + log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1)) log_rot = hat_inv(log_rot_hat) return log_rot diff --git a/tests/test_so3.py b/tests/test_so3.py index 8e261529..315e11da 100644 --- a/tests/test_so3.py +++ b/tests/test_so3.py @@ -1,10 +1,12 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import math import unittest import numpy as np import torch +from common_testing import TestCaseMixin from pytorch3d.transforms.so3 import ( hat, so3_exponential_map, @@ -13,7 +15,7 @@ from pytorch3d.transforms.so3 import ( ) -class TestSO3(unittest.TestCase): +class TestSO3(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() torch.manual_seed(42) @@ -55,9 +57,8 @@ class TestSO3(unittest.TestCase): """ log_rot = TestSO3.init_log_rot(batch_size=30) Rs = so3_exponential_map(log_rot) - for R in Rs: - det = np.linalg.det(R.cpu().numpy()) - self.assertAlmostEqual(float(det), 1.0, 5) + dets = torch.det(Rs) + self.assertClose(dets, torch.ones_like(dets), atol=1e-4) def test_cross(self): """ @@ -70,8 +71,7 @@ class TestSO3(unittest.TestCase): hat_a = hat(a) cross = torch.bmm(hat_a, b[:, :, None])[:, :, 0] torch_cross = torch.cross(a, b, dim=1) - max_df = (cross - torch_cross).abs().max() - self.assertAlmostEqual(float(max_df), 0.0, 5) + self.assertClose(torch_cross, cross, atol=1e-4) def test_bad_so3_input_value_err(self): """ @@ -126,24 +126,52 @@ class TestSO3(unittest.TestCase): """ # generate random rotations with a tiny angle device = torch.device("cuda:0") - r = torch.eye(3, device=device)[None].repeat((batch_size, 1, 1)) - r += torch.randn((batch_size, 3, 3), device=device) * 1e-3 - r = torch.stack([torch.qr(r_)[0] for r_ in r]) + identity = torch.eye(3, device=device) + rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device) + r = [identity, rot180] + r.extend( + [ + torch.qr(identity + torch.randn_like(identity) * 1e-4)[0] + for _ in range(batch_size - 2) + ] + ) + r = torch.stack(r) # the log of the rotation matrix r r_log = so3_log_map(r) # tests whether all outputs are finite r_sum = float(r_log.sum()) self.assertEqual(r_sum, r_sum) + def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100): + """ + Check that + `so3_exponential_map(so3_log_map(so3_exponential_map(log_rot))) + == so3_exponential_map(log_rot)` + for a randomly generated batch of rotation matrix logarithms `log_rot`. + Unlike `test_so3_log_to_exp_to_log`, this test allows to check the + correctness of converting `log_rot` which contains values > math.pi. + """ + log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size) + # check also the singular cases where rot. angle = {0, pi, 2pi, 3pi} + log_rot[:3] = 0 + log_rot[1, 0] = math.pi + log_rot[2, 0] = 2.0 * math.pi + log_rot[3, 0] = 3.0 * math.pi + rot = so3_exponential_map(log_rot, eps=1e-8) + rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8) + angles = so3_relative_angle(rot, rot_) + self.assertClose(angles, torch.zeros_like(angles), atol=0.01) + def test_so3_log_to_exp_to_log(self, batch_size: int = 100): """ Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for a randomly generated batch of rotation matrix logarithms `log_rot`. """ log_rot = TestSO3.init_log_rot(batch_size=batch_size) + # check also the singular cases where rot. angle = 0 + log_rot[:1] = 0 log_rot_ = so3_log_map(so3_exponential_map(log_rot)) - max_df = (log_rot - log_rot_).abs().max() - self.assertAlmostEqual(float(max_df), 0.0, 4) + self.assertClose(log_rot, log_rot_, atol=1e-4) def test_so3_exp_to_log_to_exp(self, batch_size: int = 100): """ @@ -151,12 +179,10 @@ class TestSO3(unittest.TestCase): a batch of randomly generated rotation matrices `R`. """ rot = TestSO3.init_rot(batch_size=batch_size) - rot_ = so3_exponential_map(so3_log_map(rot)) + rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8) angles = so3_relative_angle(rot, rot_) - max_angle = angles.max() - # a lot of precision lost here :( - # TODO: fix this test?? - self.assertTrue(np.allclose(float(max_angle), 0.0, atol=0.1)) + # TODO: a lot of precision lost here ... + self.assertClose(angles, torch.zeros_like(angles), atol=0.1) def test_so3_cos_angle(self, batch_size: int = 100): """ @@ -168,7 +194,7 @@ class TestSO3(unittest.TestCase): rot2 = TestSO3.init_rot(batch_size=batch_size) angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos() angles_ = so3_relative_angle(rot1, rot2, cos_angle=True) - self.assertTrue(torch.allclose(angles, angles_)) + self.assertClose(angles, angles_) @staticmethod def so3_expmap(batch_size: int = 10):