Update so3 operations for numerical stability

Summary: Replace implementations of `so3_exp_map` and `so3_log_map` in so3.py with existing more-stable implementations.

Reviewed By: bottler

Differential Revision: D52513319

fbshipit-source-id: fbfc039643fef284d8baa11bab61651964077afe
This commit is contained in:
Abdelrahman Selim
2024-01-04 02:26:56 -08:00
committed by Facebook GitHub Bot
parent 3621a36494
commit 292acc71a3
2 changed files with 6 additions and 54 deletions

View File

@@ -97,20 +97,6 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
so3_log_map(rot)
self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception))
# trace of rot definitely bigger than 3 or smaller than -1
rot = torch.cat(
(
torch.rand(size=[5, 3, 3], device=device) + 4.0,
torch.rand(size=[5, 3, 3], device=device) - 3.0,
)
)
with self.assertRaises(ValueError) as err:
so3_log_map(rot)
self.assertTrue(
"A matrix has trace outside valid range [-1-eps,3+eps]."
in str(err.exception)
)
def test_so3_exp_singularity(self, batch_size: int = 100):
"""
Tests whether the `so3_exp_map` is robust to the input vectors