Linearly extrapolated acos.

Summary:
Implements a backprop-safe version of `torch.acos` that linearly extrapolates the function outside bounds.

Below is a plot of the extrapolated acos for different bounds:
{F611339485}

Reviewed By: bottler, nikhilaravi

Differential Revision: D27945714

fbshipit-source-id: fa2e2385b56d6fe534338d5192447c4a3aec540c
This commit is contained in:
David Novotny
2021-06-21 04:47:31 -07:00
committed by Facebook GitHub Bot
parent 88f5d79088
commit dd45123f20
4 changed files with 246 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .math import acos_linear_extrapolation
from .rotation_conversions import (
axis_angle_to_matrix,
axis_angle_to_quaternion,

View File

@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
from typing import Tuple, Union
import torch
def acos_linear_extrapolation(
    x: torch.Tensor,
    bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
) -> torch.Tensor:
    """
    Compute `arccos(x)` with linear extrapolation outside a trusted region,
    so that backpropagation stays stable even when `x` drifts outside
    `(-1, 1)` (e.g. due to floating point error in upstream computation).

    Behavior:
    ```
    if lower_bound <= x <= upper_bound:
        result = acos(x)
    elif x < lower_bound:  # 1st order Taylor expansion around lower_bound
        result = acos(lower_bound) + dacos/dx(lower_bound) * (x - lower_bound)
    else:  # x > upper_bound, Taylor expansion around upper_bound
        result = acos(upper_bound) + dacos/dx(upper_bound) * (x - upper_bound)
    ```

    Args:
        x: Input `Tensor`.
        bound: A float scalar `b`, in which case `acos` is extrapolated for
            `x <= -b` and `x >= b`, or a 2-tuple `(lower_bound, upper_bound)`
            delimiting the region where plain `acos` is evaluated.
            All bounds have to lie strictly within `(-1, 1)`.

    Returns:
        `Tensor` of the same shape as `x` holding the (extrapolated) `arccos(x)`.

    Raises:
        ValueError: If the bounds are out of order or not within `(-1, 1)`.
    """
    if isinstance(bound, float):
        lower_bound, upper_bound = -bound, bound
    else:
        lower_bound, upper_bound = bound

    if lower_bound > upper_bound:
        raise ValueError("lower bound has to be smaller or equal to upper bound.")
    if lower_bound <= -1.0 or upper_bound >= 1.0:
        raise ValueError("Both lower bound and upper bound have to be within (-1, 1).")

    result = torch.empty_like(x)
    above = x >= upper_bound
    below = x <= lower_bound
    inside = ~(above | below)

    # Plain acos for lower_bound < x < upper_bound. Boolean-mask indexing
    # keeps out-of-domain entries away from `torch.acos`, whose gradient
    # would otherwise be NaN/inf there.
    result[inside] = torch.acos(x[inside])

    # First-order Taylor expansion of acos around the nearest bound for the
    # out-of-range entries (above the upper bound / below the lower bound).
    for mask, pivot in ((above, upper_bound), (below, lower_bound)):
        slope = -1.0 / math.sqrt(1.0 - pivot * pivot)
        result[mask] = (x[mask] - pivot) * slope + math.acos(pivot)

    return result
def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor:
"""
Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`.
"""
return (x - x0) * _dacos_dx(x0) + math.acos(x0)
def _dacos_dx(x: float) -> float:
"""
Calculates the derivative of `arccos(x)` w.r.t. `x`.
"""
return (-1.0) / math.sqrt(1.0 - x * x)