diff --git a/pytorch3d/transforms/math.py b/pytorch3d/transforms/math.py index 41adea17..88e3084a 100644 --- a/pytorch3d/transforms/math.py +++ b/pytorch3d/transforms/math.py @@ -9,10 +9,12 @@ from typing import Tuple, Union import torch +DEFAULT_ACOS_BOUND = 1.0 - 1e-4 + def acos_linear_extrapolation( x: torch.Tensor, - bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4, + bounds: Tuple[float, float] = (-DEFAULT_ACOS_BOUND, DEFAULT_ACOS_BOUND), ) -> torch.Tensor: """ Implements `arccos(x)` which is linearly extrapolated outside `x`'s original @@ -21,23 +23,20 @@ def acos_linear_extrapolation( More specifically: ``` - if -bound <= x <= bound: + bounds=(lower_bound, upper_bound) + if lower_bound <= x <= upper_bound: acos_linear_extrapolation(x) = acos(x) - elif x <= -bound: # 1st order Taylor approximation - acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound)) - else: # x >= bound - acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound) + elif x <= lower_bound: # 1st order Taylor approximation + acos_linear_extrapolation(x) = acos(lower_bound) + dacos/dx(lower_bound) * (x - lower_bound) + else: # x >= upper_bound + acos_linear_extrapolation(x) = acos(upper_bound) + dacos/dx(upper_bound) * (x - upper_bound) ``` - Note that `bound` can be made more specific with setting - `bound=[lower_bound, upper_bound]` as detailed below. Args: x: Input `Tensor`. - bound: A float constant or a float 2-tuple defining the region for the + bounds: A float 2-tuple defining the region for the linear extrapolation of `acos`. - If `bound` is a float scalar, linearly interpolates acos for - `x <= -bound` or `bound <= x`. - If `bound` is a 2-tuple, the first/second element of `bound` + The first/second element of `bound` describes the lower/upper bound that defines the lower/upper extrapolation region, i.e. the region where `x <= bound[0]`/`bound[1] <= x`. @@ -46,11 +45,7 @@ def acos_linear_extrapolation( acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`. """ - if isinstance(bound, float): - upper_bound = bound - lower_bound = -bound - else: - lower_bound, upper_bound = bound + lower_bound, upper_bound = bounds if lower_bound > upper_bound: raise ValueError("lower bound has to be smaller or equal to upper bound.") diff --git a/pytorch3d/transforms/so3.py b/pytorch3d/transforms/so3.py index 0c6ced7d..50004ed7 100644 --- a/pytorch3d/transforms/so3.py +++ b/pytorch3d/transforms/so3.py @@ -12,9 +12,6 @@ import torch from ..transforms import acos_linear_extrapolation -HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5 - - def so3_relative_angle( R1: torch.Tensor, R2: torch.Tensor, @@ -104,7 +101,8 @@ def so3_rotation_angle( return phi_cos else: if cos_bound > 0.0: - return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound) + bound = 1.0 - cos_bound + return acos_linear_extrapolation(phi_cos, (-bound, bound)) else: return torch.acos(phi_cos) @@ -250,6 +248,8 @@ def hat_inv(h: torch.Tensor) -> torch.Tensor: raise ValueError("Input has to be a batch of 3x3 Tensors.") ss_diff = torch.abs(h + h.permute(0, 2, 1)).max() + + HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5 if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL: raise ValueError("One of input matrices is not skew-symmetric.") diff --git a/tests/test_acos_linear_extrapolation.py b/tests/test_acos_linear_extrapolation.py index 35076509..339b4de7 100644 --- a/tests/test_acos_linear_extrapolation.py +++ b/tests/test_acos_linear_extrapolation.py @@ -101,11 +101,6 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase): self._test_acos_outside_bounds( x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound ) - if abs(upper_bound + lower_bound) <= 1e-5: # lower_bound==-upper_bound - # check that passing bounds=upper_bound gives the same - # resut as bounds=[lower_bound, upper_bound] - y_one_bound = acos_linear_extrapolation(x, upper_bound) - self.assertClose(y_one_bound, y) def test_acos(self, batch_size: int = 10000): """ diff --git a/tests/test_so3.py b/tests/test_so3.py index 414733d8..ad6b230d 100644 --- a/tests/test_so3.py +++ b/tests/test_so3.py @@ -7,6 +7,7 @@ import math import unittest +from distutils.version import LooseVersion import numpy as np import torch @@ -268,6 +269,11 @@ class TestSO3(TestCaseMixin, unittest.TestCase): # all grad values have to be finite self.assertTrue(torch.isfinite(r.grad).all()) + @unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only") + def test_scriptable(self): + torch.jit.script(so3_exp_map) + torch.jit.script(so3_log_map) + @staticmethod def so3_expmap(batch_size: int = 10): log_rot = TestSO3.init_log_rot(batch_size=batch_size)