Linearly extrapolated acos.

Summary: Implements a backprop-safe version of `torch.acos` that linearly extrapolates the function outside the given bounds. Below is a plot of the extrapolated acos for different bounds: {F611339485}

Reviewed By: bottler, nikhilaravi

Differential Revision: D27945714

fbshipit-source-id: fa2e2385b56d6fe534338d5192447c4a3aec540c
parent 88f5d79088
commit dd45123f20
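For context, a minimal sketch (not part of this commit) of the failure mode being addressed: the derivative of `acos` is `-1/sqrt(1 - x^2)`, which diverges as `|x| -> 1`, so plain `torch.acos` produces infinite gradients at the boundary of its domain, onto which floating-point round-off easily pushes inputs:

```
import torch

# d/dx acos(x) = -1 / sqrt(1 - x^2) blows up as |x| -> 1.
x = torch.tensor([1.0, 0.999999], requires_grad=True)
torch.acos(x).sum().backward()
print(x.grad)  # tensor([-inf, ...]) -- gradients explode near the boundary
```

The new `acos_linear_extrapolation` below instead continues `acos` along its tangent line outside a slightly shrunken interval, keeping both values and gradients finite.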
pytorch3d/transforms/__init__.py

@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

+from .math import acos_linear_extrapolation
 from .rotation_conversions import (
     axis_angle_to_matrix,
     axis_angle_to_quaternion,
pytorch3d/transforms/math.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import math
from typing import Tuple, Union

import torch


def acos_linear_extrapolation(
    x: torch.Tensor,
    bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
) -> torch.Tensor:
    """
    Implements `arccos(x)` which is linearly extrapolated outside `x`'s original
    domain of `(-1, 1)`. This allows for stable backpropagation in case `x`
    is not guaranteed to be strictly within `(-1, 1)`.

    More specifically:
    ```
    if -bound <= x <= bound:
        acos_linear_extrapolation(x) = acos(x)
    elif x <= -bound:  # 1st order Taylor approximation
        acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound))
    else:  # x >= bound
        acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound)
    ```
    Note that `bound` can be made more specific by setting
    `bound=[lower_bound, upper_bound]` as detailed below.

    Args:
        x: Input `Tensor`.
        bound: A float constant or a float 2-tuple defining the region for the
            linear extrapolation of `acos`.
            If `bound` is a float scalar, linearly extrapolates acos for
            `x <= -bound` or `bound <= x`.
            If `bound` is a 2-tuple, the first/second element of `bound`
            describes the lower/upper bound that defines the lower/upper
            extrapolation region, i.e. the region where
            `x <= bound[0]`/`bound[1] <= x`.
            Note that all elements of `bound` have to be within (-1, 1).
    Returns:
        acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`.
    """

    if isinstance(bound, float):
        upper_bound = bound
        lower_bound = -bound
    else:
        lower_bound, upper_bound = bound

    if lower_bound > upper_bound:
        raise ValueError("lower bound has to be smaller or equal to upper bound.")

    if lower_bound <= -1.0 or upper_bound >= 1.0:
        raise ValueError("Both lower bound and upper bound have to be within (-1, 1).")

    # init an empty tensor and define the domain sets
    acos_extrap = torch.empty_like(x)
    x_upper = x >= upper_bound
    x_lower = x <= lower_bound
    x_mid = (~x_upper) & (~x_lower)

    # acos calculation for lower_bound < x < upper_bound
    acos_extrap[x_mid] = torch.acos(x[x_mid])
    # the linear extrapolation for x >= upper_bound
    acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound)
    # the linear extrapolation for x <= lower_bound
    acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound)

    return acos_extrap


def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor:
    """
    Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`.
    """
    return (x - x0) * _dacos_dx(x0) + math.acos(x0)


def _dacos_dx(x: float) -> float:
    """
    Calculates the derivative of `arccos(x)` w.r.t. `x`.
    """
    return (-1.0) / math.sqrt(1.0 - x * x)
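A minimal usage sketch (not part of the commit), assuming inputs that stray slightly outside `[-1, 1]`, as round-off in dot products of unit vectors tends to produce:

```
import torch
from pytorch3d.transforms import acos_linear_extrapolation

# Values slightly outside [-1, 1], e.g. from floating-point round-off.
x = torch.tensor([-1.0001, -0.5, 0.0, 0.5, 1.0001], requires_grad=True)

y = acos_linear_extrapolation(x, (-0.9999, 0.9999))
y.sum().backward()

# Inside the bounds the result matches torch.acos; outside, it continues
# along the tangent line, so both values and gradients stay finite.
assert torch.isfinite(y).all() and torch.isfinite(x.grad).all()
```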
tests/bm_acos_linear_extrapolation.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from fvcore.common.benchmark import benchmark
from test_acos_linear_extrapolation import TestAcosLinearExtrapolation


def bm_acos_linear_extrapolation() -> None:
    kwargs_list = [
        {"batch_size": 1},
        {"batch_size": 100},
        {"batch_size": 10000},
        {"batch_size": 1000000},
    ]
    benchmark(
        TestAcosLinearExtrapolation.acos_linear_extrapolation,
        "ACOS_LINEAR_EXTRAPOLATION",
        kwargs_list,
        warmup_iters=1,
    )


if __name__ == "__main__":
    bm_acos_linear_extrapolation()
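For reference, `fvcore`'s `benchmark` helper passes each entry of `kwargs_list` to the given callable, which performs its setup and returns a zero-argument closure that is then timed. Calling the static method directly shows the pattern in isolation (note the setup allocates on `cuda:0`, so a GPU is required):

```
# Setup runs once; `benchmark` then times repeated calls of the closure.
run = TestAcosLinearExtrapolation.acos_linear_extrapolation(batch_size=100)
run()  # one timed iteration: extrapolated acos over 100 boundary-heavy values
```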
tests/test_acos_linear_extrapolation.py (new file, 139 lines)
@@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import unittest

import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms import acos_linear_extrapolation


class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(42)
        np.random.seed(42)

    @staticmethod
    def init_acos_boundary_values(batch_size: int = 10000):
        """
        Initialize a tensor containing values close to the bounds of the
        domain of `acos`, i.e. close to -1 or 1; and random values between (-1, 1).
        """
        device = torch.device("cuda:0")
        # one quarter are random values between -1 and 1
        x_rand = 2 * torch.rand(batch_size // 4, dtype=torch.float32, device=device) - 1
        x = [x_rand]
        for bound in [-1, 1]:
            for above_bound in [True, False]:
                for noise_std in [1e-4, 1e-2]:
                    n_generate = (batch_size - batch_size // 4) // 8
                    x_add = (
                        bound
                        + (2 * float(above_bound) - 1)
                        * torch.randn(
                            n_generate, device=device, dtype=torch.float32
                        ).abs()
                        * noise_std
                    )
                    x.append(x_add)
        x = torch.cat(x)
        return x

    @staticmethod
    def acos_linear_extrapolation(batch_size: int):
        x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
        torch.cuda.synchronize()

        def compute_acos():
            acos_linear_extrapolation(x)
            torch.cuda.synchronize()

        return compute_acos

    def _test_acos_outside_bounds(self, x, y, dydx, bound):
        """
        Check that `acos_linear_extrapolation` yields points on a line with correct
        slope, and that the function is continuous around `bound`.
        """
        bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
        # fit a line: slope * x + bias = y
        x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
        solution = torch.linalg.lstsq(x_1, y[:, None]).solution
        slope, bias = solution.view(-1)[:2]
        desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2)
        # test that the desired slope is the same as the fitted one
        self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
        # test that the autograd's slope is the same as the desired one
        self.assertClose(desired_slope.expand_as(dydx), dydx, atol=1e-2)
        # test that the value of the fitted line at x=bound equals
        # arccos(x), i.e. the function is continuous around the bound
        y_bound_lin = (slope * bound_t + bias).view(1)
        y_bound_acos = bound_t.acos().view(1)
        self.assertClose(y_bound_lin, y_bound_acos, atol=1e-2)

    def _one_acos_test(self, x: torch.Tensor, lower_bound: float, upper_bound: float):
        """
        Test that `acos_linear_extrapolation` returns correct values for
        `x` between/above/below `lower_bound`/`upper_bound`.
        """
        x.requires_grad = True
        x.grad = None
        y = acos_linear_extrapolation(x, [lower_bound, upper_bound])
        # compute the gradient of the acos w.r.t. x
        y.backward(torch.ones_like(y))
        dacos_dx = x.grad
        x_lower = x <= lower_bound
        x_upper = x >= upper_bound
        x_mid = (~x_lower) & (~x_upper)
        # test that between bounds, the function returns plain acos
        self.assertClose(x[x_mid].acos(), y[x_mid])
        # test that outside the bounds, the function is linear with the right
        # slope and continuous around the bound
        self._test_acos_outside_bounds(
            x[x_upper], y[x_upper], dacos_dx[x_upper], upper_bound
        )
        self._test_acos_outside_bounds(
            x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound
        )
        if abs(upper_bound + lower_bound) <= 1e-5:  # lower_bound == -upper_bound
            # check that passing bound=upper_bound gives the same
            # result as bound=[lower_bound, upper_bound]
            y_one_bound = acos_linear_extrapolation(x, upper_bound)
            self.assertClose(y_one_bound, y)

    def test_acos(self, batch_size: int = 10000):
        """
        Tests whether the function returns correct outputs
        inside/outside the bounds.
        """
        x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
        bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5)
        for lower_bound in -bounds:
            for upper_bound in bounds:
                if upper_bound < lower_bound:
                    continue
                self._one_acos_test(x, float(lower_bound), float(upper_bound))

    def test_finite_gradient(self, batch_size: int = 10000):
        """
        Tests whether gradients stay finite close to the bounds.
        """
        x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
        x.requires_grad = True
        bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5)
        for lower_bound in -bounds:
            for upper_bound in bounds:
                if upper_bound < lower_bound:
                    continue
                x.grad = None
                y = acos_linear_extrapolation(
                    x,
                    [float(lower_bound), float(upper_bound)],
                )
                self.assertTrue(torch.isfinite(y).all())
                loss = y.mean()
                loss.backward()
                self.assertIsNotNone(x.grad)
                self.assertTrue(torch.isfinite(x.grad).all())
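Beyond the line-fit and finiteness checks above, the backward pass could also be verified numerically; a sketch (not part of the commit) using `torch.autograd.gradcheck`, assuming double-precision inputs kept away from the kinks at the bounds, where finite differences would be unreliable:

```
import torch
from pytorch3d.transforms import acos_linear_extrapolation

x = torch.tensor([-1.5, -0.5, 0.3, 1.5], dtype=torch.float64, requires_grad=True)
# Sample points avoid the bounds (+-0.9), where the piecewise function has a kink.
torch.autograd.gradcheck(lambda t: acos_linear_extrapolation(t, (-0.9, 0.9)), (x,))
```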