make so3_log_map torch script compatible

Summary:
* HAT_INV_SKEW_SYMMETRIC_TOL was a global variable and torch script gives an error when compiling that function. Move it to the function scope.
* torch script gives error when compiling acos_linear_extrapolation because bound is a union of tuple and float. The tuple version is kept in this diff.

Reviewed By: patricklabatut

Differential Revision: D30614916

fbshipit-source-id: 34258d200dc6a09fbf8917cac84ba8a269c00aef
This commit is contained in:
Shangchen Han
2021-09-10 11:12:13 -07:00
committed by Facebook GitHub Bot
parent c3d7808868
commit 46f727cb68
4 changed files with 22 additions and 26 deletions

View File

@@ -101,11 +101,6 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
self._test_acos_outside_bounds(
x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound
)
if abs(upper_bound + lower_bound) <= 1e-5: # lower_bound==-upper_bound
# check that passing bounds=upper_bound gives the same
# resut as bounds=[lower_bound, upper_bound]
y_one_bound = acos_linear_extrapolation(x, upper_bound)
self.assertClose(y_one_bound, y)
def test_acos(self, batch_size: int = 10000):
"""

View File

@@ -7,6 +7,7 @@
import math
import unittest
from distutils.version import LooseVersion
import numpy as np
import torch
@@ -268,6 +269,11 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
# all grad values have to be finite
self.assertTrue(torch.isfinite(r.grad).all())
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
def test_scriptable(self):
torch.jit.script(so3_exp_map)
torch.jit.script(so3_log_map)
@staticmethod
def so3_expmap(batch_size: int = 10):
log_rot = TestSO3.init_log_rot(batch_size=batch_size)