From dd45123f202441e7539c4af9b35d07317b786528 Mon Sep 17 00:00:00 2001
From: David Novotny
Date: Mon, 21 Jun 2021 04:47:31 -0700
Subject: [PATCH] Linearly extrapolated acos.

Summary:
Implements a backprop-safe version of `torch.acos` that linearly extrapolates
the function outside the chosen bounds.

Below is a plot of the extrapolated acos for different bounds:
{F611339485}

Reviewed By: bottler, nikhilaravi

Differential Revision: D27945714

fbshipit-source-id: fa2e2385b56d6fe534338d5192447c4a3aec540c
---
 pytorch3d/transforms/__init__.py        |   1 +
 pytorch3d/transforms/math.py            |  83 ++++++++++++++
 tests/bm_acos_linear_extrapolation.py   |  23 ++++
 tests/test_acos_linear_extrapolation.py | 139 ++++++++++++++++++++++++
 4 files changed, 246 insertions(+)
 create mode 100644 pytorch3d/transforms/math.py
 create mode 100644 tests/bm_acos_linear_extrapolation.py
 create mode 100644 tests/test_acos_linear_extrapolation.py

diff --git a/pytorch3d/transforms/__init__.py b/pytorch3d/transforms/__init__.py
index d32ba792..5709bd6a 100644
--- a/pytorch3d/transforms/__init__.py
+++ b/pytorch3d/transforms/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+from .math import acos_linear_extrapolation
 from .rotation_conversions import (
     axis_angle_to_matrix,
     axis_angle_to_quaternion,
diff --git a/pytorch3d/transforms/math.py b/pytorch3d/transforms/math.py
new file mode 100644
index 00000000..ff513d69
--- /dev/null
+++ b/pytorch3d/transforms/math.py
@@ -0,0 +1,83 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import math
+from typing import Tuple, Union
+
+import torch
+
+
+def acos_linear_extrapolation(
+    x: torch.Tensor,
+    bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
+) -> torch.Tensor:
+    """
+    Implements `arccos(x)` which is linearly extrapolated outside `x`'s
+    original domain of `[-1, 1]`. This allows for stable backpropagation
+    in case `x` is not guaranteed to be strictly within `(-1, 1)`.
+
+    More specifically:
+    ```
+    if -bound <= x <= bound:
+        acos_linear_extrapolation(x) = acos(x)
+    elif x <= -bound:  # 1st order Taylor approximation
+        acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound))
+    else:  # x >= bound
+        acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound)
+    ```
+    Note that `bound` can be made more specific by setting
+    `bound=[lower_bound, upper_bound]` as detailed below.
+
+    Args:
+        x: Input `Tensor`.
+        bound: A float constant or a float 2-tuple defining the region for the
+            linear extrapolation of `acos`.
+            If `bound` is a float scalar, linearly extrapolates acos for
+            `x <= -bound` or `bound <= x`.
+            If `bound` is a 2-tuple, the first/second element of `bound`
+            describes the lower/upper bound that defines the lower/upper
+            extrapolation region, i.e. the region where
+            `x <= bound[0]`/`bound[1] <= x`.
+            Note that all elements of `bound` have to be within (-1, 1).
+    Returns:
+        acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`.
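+
+    Example:
+        Out-of-domain inputs keep finite values and gradients, whereas
+        plain `torch.acos` would return `nan` for them::
+
+            x = torch.tensor([-1.2, 0.0, 1.2], requires_grad=True)
+            y = acos_linear_extrapolation(x)  # finite for all entries
+            y.sum().backward()  # x.grad is finite as well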
+ """ + + if isinstance(bound, float): + upper_bound = bound + lower_bound = -bound + else: + lower_bound, upper_bound = bound + + if lower_bound > upper_bound: + raise ValueError("lower bound has to be smaller or equal to upper bound.") + + if lower_bound <= -1.0 or upper_bound >= 1.0: + raise ValueError("Both lower bound and upper bound have to be within (-1, 1).") + + # init an empty tensor and define the domain sets + acos_extrap = torch.empty_like(x) + x_upper = x >= upper_bound + x_lower = x <= lower_bound + x_mid = (~x_upper) & (~x_lower) + + # acos calculation for upper_bound < x < lower_bound + acos_extrap[x_mid] = torch.acos(x[x_mid]) + # the linear extrapolation for x >= upper_bound + acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound) + # the linear extrapolation for x <= lower_bound + acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound) + + return acos_extrap + + +def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor: + """ + Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`. + """ + return (x - x0) * _dacos_dx(x0) + math.acos(x0) + + +def _dacos_dx(x: float) -> float: + """ + Calculates the derivative of `arccos(x)` w.r.t. `x`. + """ + return (-1.0) / math.sqrt(1.0 - x * x) diff --git a/tests/bm_acos_linear_extrapolation.py b/tests/bm_acos_linear_extrapolation.py new file mode 100644 index 00000000..7902b292 --- /dev/null +++ b/tests/bm_acos_linear_extrapolation.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from fvcore.common.benchmark import benchmark +from test_acos_linear_extrapolation import TestAcosLinearExtrapolation + + +def bm_acos_linear_extrapolation() -> None: + kwargs_list = [ + {"batch_size": 1}, + {"batch_size": 100}, + {"batch_size": 10000}, + {"batch_size": 1000000}, + ] + benchmark( + TestAcosLinearExtrapolation.acos_linear_extrapolation, + "ACOS_LINEAR_EXTRAPOLATION", + kwargs_list, + warmup_iters=1, + ) + + +if __name__ == "__main__": + bm_acos_linear_extrapolation() diff --git a/tests/test_acos_linear_extrapolation.py b/tests/test_acos_linear_extrapolation.py new file mode 100644 index 00000000..b4c73a7c --- /dev/null +++ b/tests/test_acos_linear_extrapolation.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import unittest + +import numpy as np +import torch +from common_testing import TestCaseMixin +from pytorch3d.transforms import acos_linear_extrapolation + + +class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + + @staticmethod + def init_acos_boundary_values(batch_size: int = 10000): + """ + Initialize a tensor containing values close to the bounds of the + domain of `acos`, i.e. close to -1 or 1; and random values between (-1, 1). 
+ """ + device = torch.device("cuda:0") + # one quarter are random values between -1 and 1 + x_rand = 2 * torch.rand(batch_size // 4, dtype=torch.float32, device=device) - 1 + x = [x_rand] + for bound in [-1, 1]: + for above_bound in [True, False]: + for noise_std in [1e-4, 1e-2]: + n_generate = (batch_size - batch_size // 4) // 8 + x_add = ( + bound + + (2 * float(above_bound) - 1) + * torch.randn( + n_generate, device=device, dtype=torch.float32 + ).abs() + * noise_std + ) + x.append(x_add) + x = torch.cat(x) + return x + + @staticmethod + def acos_linear_extrapolation(batch_size: int): + x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size) + torch.cuda.synchronize() + + def compute_acos(): + acos_linear_extrapolation(x) + torch.cuda.synchronize() + + return compute_acos + + def _test_acos_outside_bounds(self, x, y, dydx, bound): + """ + Check that `acos_linear_extrapolation` yields points on a line with correct + slope, and that the function is continuous around `bound`. + """ + bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype) + # fit a line: slope * x + bias = y + x_1 = torch.stack([x, torch.ones_like(x)], dim=-1) + solution = torch.linalg.lstsq(x_1, y[:, None]).solution + slope, bias = solution.view(-1)[:2] + desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2) + # test that the desired slope is the same as the fitted one + self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2) + # test that the autograd's slope is the same as the desired one + self.assertClose(desired_slope.expand_as(dydx), dydx, atol=1e-2) + # test that the value of the fitted line at x=bound equals + # arccos(x), i.e. the function is continuous around the bound + y_bound_lin = (slope * bound_t + bias).view(1) + y_bound_acos = bound_t.acos().view(1) + self.assertClose(y_bound_lin, y_bound_acos, atol=1e-2) + + def _one_acos_test(self, x: torch.Tensor, lower_bound: float, upper_bound: float): + """ + Test that `acos_linear_extrapolation` returns correct values for + `x` between/above/below `lower_bound`/`upper_bound`. + """ + x.requires_grad = True + x.grad = None + y = acos_linear_extrapolation(x, [lower_bound, upper_bound]) + # compute the gradient of the acos w.r.t. x + y.backward(torch.ones_like(y)) + dacos_dx = x.grad + x_lower = x <= lower_bound + x_upper = x >= upper_bound + x_mid = (~x_lower) & (~x_upper) + # test that between bounds, the function returns plain acos + self.assertClose(x[x_mid].acos(), y[x_mid]) + # test that outside the bounds, the function is linear with the right + # slope and continuous around the bound + self._test_acos_outside_bounds( + x[x_upper], y[x_upper], dacos_dx[x_upper], upper_bound + ) + self._test_acos_outside_bounds( + x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound + ) + if abs(upper_bound + lower_bound) <= 1e-5: # lower_bound==-upper_bound + # check that passing bounds=upper_bound gives the same + # resut as bounds=[lower_bound, upper_bound] + y_one_bound = acos_linear_extrapolation(x, upper_bound) + self.assertClose(y_one_bound, y) + + def test_acos(self, batch_size: int = 10000): + """ + Tests whether the function returns correct outputs + inside/outside the bounds. 
+ """ + x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size) + bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5) + for lower_bound in -bounds: + for upper_bound in bounds: + if upper_bound < lower_bound: + continue + self._one_acos_test(x, float(lower_bound), float(upper_bound)) + + def test_finite_gradient(self, batch_size: int = 10000): + """ + Tests whether gradients stay finite close to the bounds. + """ + x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size) + x.requires_grad = True + bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5) + for lower_bound in -bounds: + for upper_bound in bounds: + if upper_bound < lower_bound: + continue + x.grad = None + y = acos_linear_extrapolation( + x, + [float(lower_bound), float(upper_bound)], + ) + self.assertTrue(torch.isfinite(y).all()) + loss = y.mean() + loss.backward() + self.assertIsNotNone(x.grad) + self.assertTrue(torch.isfinite(x.grad).all())