make so3_log_map torch script compatible

Summary:
* HAT_INV_SKEW_SYMMETRIC_TOL was a global variable and torch script gives an error when compiling that function. Move it to the function scope.
* torch script gives error when compiling acos_linear_extrapolation because bound is a union of tuple and float. The tuple version is kept in this diff.

Reviewed By: patricklabatut

Differential Revision: D30614916

fbshipit-source-id: 34258d200dc6a09fbf8917cac84ba8a269c00aef
This commit is contained in:
Shangchen Han 2021-09-10 11:12:13 -07:00 committed by Facebook GitHub Bot
parent c3d7808868
commit 46f727cb68
4 changed files with 22 additions and 26 deletions

View File

@ -9,10 +9,12 @@ from typing import Tuple, Union
import torch
DEFAULT_ACOS_BOUND = 1.0 - 1e-4
def acos_linear_extrapolation(
x: torch.Tensor,
bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
bounds: Tuple[float, float] = (-DEFAULT_ACOS_BOUND, DEFAULT_ACOS_BOUND),
) -> torch.Tensor:
"""
Implements `arccos(x)` which is linearly extrapolated outside `x`'s original
@ -21,23 +23,20 @@ def acos_linear_extrapolation(
More specifically:
```
if -bound <= x <= bound:
bounds=(lower_bound, upper_bound)
if lower_bound <= x <= upper_bound:
acos_linear_extrapolation(x) = acos(x)
elif x <= -bound: # 1st order Taylor approximation
acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound))
else: # x >= bound
acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound)
elif x <= lower_bound: # 1st order Taylor approximation
acos_linear_extrapolation(x) = acos(lower_bound) + dacos/dx(lower_bound) * (x - lower_bound)
else: # x >= upper_bound
acos_linear_extrapolation(x) = acos(upper_bound) + dacos/dx(upper_bound) * (x - upper_bound)
```
Note that `bound` can be made more specific with setting
`bound=[lower_bound, upper_bound]` as detailed below.
Args:
x: Input `Tensor`.
bound: A float constant or a float 2-tuple defining the region for the
bounds: A float 2-tuple defining the region for the
linear extrapolation of `acos`.
If `bound` is a float scalar, linearly interpolates acos for
`x <= -bound` or `bound <= x`.
If `bound` is a 2-tuple, the first/second element of `bound`
The first/second element of `bound`
describes the lower/upper bound that defines the lower/upper
extrapolation region, i.e. the region where
`x <= bound[0]`/`bound[1] <= x`.
@ -46,11 +45,7 @@ def acos_linear_extrapolation(
acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`.
"""
if isinstance(bound, float):
upper_bound = bound
lower_bound = -bound
else:
lower_bound, upper_bound = bound
lower_bound, upper_bound = bounds
if lower_bound > upper_bound:
raise ValueError("lower bound has to be smaller or equal to upper bound.")

View File

@ -12,9 +12,6 @@ import torch
from ..transforms import acos_linear_extrapolation
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
def so3_relative_angle(
R1: torch.Tensor,
R2: torch.Tensor,
@ -104,7 +101,8 @@ def so3_rotation_angle(
return phi_cos
else:
if cos_bound > 0.0:
return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
bound = 1.0 - cos_bound
return acos_linear_extrapolation(phi_cos, (-bound, bound))
else:
return torch.acos(phi_cos)
@ -250,6 +248,8 @@ def hat_inv(h: torch.Tensor) -> torch.Tensor:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
ss_diff = torch.abs(h + h.permute(0, 2, 1)).max()
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
raise ValueError("One of input matrices is not skew-symmetric.")

View File

@ -101,11 +101,6 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
self._test_acos_outside_bounds(
x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound
)
if abs(upper_bound + lower_bound) <= 1e-5: # lower_bound==-upper_bound
# check that passing bounds=upper_bound gives the same
# resut as bounds=[lower_bound, upper_bound]
y_one_bound = acos_linear_extrapolation(x, upper_bound)
self.assertClose(y_one_bound, y)
def test_acos(self, batch_size: int = 10000):
"""

View File

@ -7,6 +7,7 @@
import math
import unittest
from distutils.version import LooseVersion
import numpy as np
import torch
@ -268,6 +269,11 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
# all grad values have to be finite
self.assertTrue(torch.isfinite(r.grad).all())
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
def test_scriptable(self):
torch.jit.script(so3_exp_map)
torch.jit.script(so3_log_map)
@staticmethod
def so3_expmap(batch_size: int = 10):
log_rot = TestSO3.init_log_rot(batch_size=batch_size)