mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
SO3 log map fix for singularity at PI
Summary: Fixes the case where the rotation angle is exactly 0/PI. Added a test for `so3_log_map(identity_matrix)`. Reviewed By: nikhilaravi Differential Revision: D21477078 fbshipit-source-id: adff804da97f6f0d4f50aa1f6904a34832cb8bfe
This commit is contained in:
parent
17ca6ecd81
commit
34a0df0630
@ -152,11 +152,14 @@ def so3_log_map(R, eps: float = 0.0001):
|
|||||||
|
|
||||||
phi = so3_rotation_angle(R)
|
phi = so3_rotation_angle(R)
|
||||||
|
|
||||||
phi_valid = torch.clamp(phi.abs(), eps) * phi.sign()
|
phi_sin = phi.sin()
|
||||||
|
|
||||||
log_rot_hat = (phi_valid / (2.0 * phi_valid.sin()))[:, None, None] * (
|
phi_denom = (
|
||||||
R - R.permute(0, 2, 1)
|
torch.clamp(phi_sin.abs(), eps) * phi_sin.sign()
|
||||||
|
+ (phi_sin == 0).type_as(phi) * eps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1))
|
||||||
log_rot = hat_inv(log_rot_hat)
|
log_rot = hat_inv(log_rot_hat)
|
||||||
|
|
||||||
return log_rot
|
return log_rot
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.transforms.so3 import (
|
from pytorch3d.transforms.so3 import (
|
||||||
hat,
|
hat,
|
||||||
so3_exponential_map,
|
so3_exponential_map,
|
||||||
@ -13,7 +15,7 @@ from pytorch3d.transforms.so3 import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSO3(unittest.TestCase):
|
class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
@ -55,9 +57,8 @@ class TestSO3(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
log_rot = TestSO3.init_log_rot(batch_size=30)
|
log_rot = TestSO3.init_log_rot(batch_size=30)
|
||||||
Rs = so3_exponential_map(log_rot)
|
Rs = so3_exponential_map(log_rot)
|
||||||
for R in Rs:
|
dets = torch.det(Rs)
|
||||||
det = np.linalg.det(R.cpu().numpy())
|
self.assertClose(dets, torch.ones_like(dets), atol=1e-4)
|
||||||
self.assertAlmostEqual(float(det), 1.0, 5)
|
|
||||||
|
|
||||||
def test_cross(self):
|
def test_cross(self):
|
||||||
"""
|
"""
|
||||||
@ -70,8 +71,7 @@ class TestSO3(unittest.TestCase):
|
|||||||
hat_a = hat(a)
|
hat_a = hat(a)
|
||||||
cross = torch.bmm(hat_a, b[:, :, None])[:, :, 0]
|
cross = torch.bmm(hat_a, b[:, :, None])[:, :, 0]
|
||||||
torch_cross = torch.cross(a, b, dim=1)
|
torch_cross = torch.cross(a, b, dim=1)
|
||||||
max_df = (cross - torch_cross).abs().max()
|
self.assertClose(torch_cross, cross, atol=1e-4)
|
||||||
self.assertAlmostEqual(float(max_df), 0.0, 5)
|
|
||||||
|
|
||||||
def test_bad_so3_input_value_err(self):
|
def test_bad_so3_input_value_err(self):
|
||||||
"""
|
"""
|
||||||
@ -126,24 +126,52 @@ class TestSO3(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
# generate random rotations with a tiny angle
|
# generate random rotations with a tiny angle
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
r = torch.eye(3, device=device)[None].repeat((batch_size, 1, 1))
|
identity = torch.eye(3, device=device)
|
||||||
r += torch.randn((batch_size, 3, 3), device=device) * 1e-3
|
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
|
||||||
r = torch.stack([torch.qr(r_)[0] for r_ in r])
|
r = [identity, rot180]
|
||||||
|
r.extend(
|
||||||
|
[
|
||||||
|
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||||
|
for _ in range(batch_size - 2)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
r = torch.stack(r)
|
||||||
# the log of the rotation matrix r
|
# the log of the rotation matrix r
|
||||||
r_log = so3_log_map(r)
|
r_log = so3_log_map(r)
|
||||||
# tests whether all outputs are finite
|
# tests whether all outputs are finite
|
||||||
r_sum = float(r_log.sum())
|
r_sum = float(r_log.sum())
|
||||||
self.assertEqual(r_sum, r_sum)
|
self.assertEqual(r_sum, r_sum)
|
||||||
|
|
||||||
|
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
|
||||||
|
"""
|
||||||
|
Check that
|
||||||
|
`so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
|
||||||
|
== so3_exponential_map(log_rot)`
|
||||||
|
for a randomly generated batch of rotation matrix logarithms `log_rot`.
|
||||||
|
Unlike `test_so3_log_to_exp_to_log`, this test allows to check the
|
||||||
|
correctness of converting `log_rot` which contains values > math.pi.
|
||||||
|
"""
|
||||||
|
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
|
||||||
|
# check also the singular cases where rot. angle = {0, pi, 2pi, 3pi}
|
||||||
|
log_rot[:3] = 0
|
||||||
|
log_rot[1, 0] = math.pi
|
||||||
|
log_rot[2, 0] = 2.0 * math.pi
|
||||||
|
log_rot[3, 0] = 3.0 * math.pi
|
||||||
|
rot = so3_exponential_map(log_rot, eps=1e-8)
|
||||||
|
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
|
||||||
|
angles = so3_relative_angle(rot, rot_)
|
||||||
|
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
|
||||||
|
|
||||||
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
|
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
|
||||||
"""
|
"""
|
||||||
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
|
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
|
||||||
a randomly generated batch of rotation matrix logarithms `log_rot`.
|
a randomly generated batch of rotation matrix logarithms `log_rot`.
|
||||||
"""
|
"""
|
||||||
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
|
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
|
||||||
|
# check also the singular cases where rot. angle = 0
|
||||||
|
log_rot[:1] = 0
|
||||||
log_rot_ = so3_log_map(so3_exponential_map(log_rot))
|
log_rot_ = so3_log_map(so3_exponential_map(log_rot))
|
||||||
max_df = (log_rot - log_rot_).abs().max()
|
self.assertClose(log_rot, log_rot_, atol=1e-4)
|
||||||
self.assertAlmostEqual(float(max_df), 0.0, 4)
|
|
||||||
|
|
||||||
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
|
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
|
||||||
"""
|
"""
|
||||||
@ -151,12 +179,10 @@ class TestSO3(unittest.TestCase):
|
|||||||
a batch of randomly generated rotation matrices `R`.
|
a batch of randomly generated rotation matrices `R`.
|
||||||
"""
|
"""
|
||||||
rot = TestSO3.init_rot(batch_size=batch_size)
|
rot = TestSO3.init_rot(batch_size=batch_size)
|
||||||
rot_ = so3_exponential_map(so3_log_map(rot))
|
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
|
||||||
angles = so3_relative_angle(rot, rot_)
|
angles = so3_relative_angle(rot, rot_)
|
||||||
max_angle = angles.max()
|
# TODO: a lot of precision lost here ...
|
||||||
# a lot of precision lost here :(
|
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
|
||||||
# TODO: fix this test??
|
|
||||||
self.assertTrue(np.allclose(float(max_angle), 0.0, atol=0.1))
|
|
||||||
|
|
||||||
def test_so3_cos_angle(self, batch_size: int = 100):
|
def test_so3_cos_angle(self, batch_size: int = 100):
|
||||||
"""
|
"""
|
||||||
@ -168,7 +194,7 @@ class TestSO3(unittest.TestCase):
|
|||||||
rot2 = TestSO3.init_rot(batch_size=batch_size)
|
rot2 = TestSO3.init_rot(batch_size=batch_size)
|
||||||
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
|
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
|
||||||
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
|
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
|
||||||
self.assertTrue(torch.allclose(angles, angles_))
|
self.assertClose(angles, angles_)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def so3_expmap(batch_size: int = 10):
|
def so3_expmap(batch_size: int = 10):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user