mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Linearly extrapolated acos.
Summary:
Implements a backprop-safe version of `torch.acos` that linearly extrapolates the function outside bounds.
Below is a plot of the extrapolated acos for different bounds:
{F611339485}
Reviewed By: bottler, nikhilaravi
Differential Revision: D27945714
fbshipit-source-id: fa2e2385b56d6fe534338d5192447c4a3aec540c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
88f5d79088
commit
dd45123f20
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .math import acos_linear_extrapolation
|
||||
from .rotation_conversions import (
|
||||
axis_angle_to_matrix,
|
||||
axis_angle_to_quaternion,
|
||||
|
||||
83
pytorch3d/transforms/math.py
Normal file
83
pytorch3d/transforms/math.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def acos_linear_extrapolation(
|
||||
x: torch.Tensor,
|
||||
bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Implements `arccos(x)` which is linearly extrapolated outside `x`'s original
|
||||
domain of `(-1, 1)`. This allows for stable backpropagation in case `x`
|
||||
is not guaranteed to be strictly within `(-1, 1)`.
|
||||
|
||||
More specifically:
|
||||
```
|
||||
if -bound <= x <= bound:
|
||||
acos_linear_extrapolation(x) = acos(x)
|
||||
elif x <= -bound: # 1st order Taylor approximation
|
||||
acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound))
|
||||
else: # x >= bound
|
||||
acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound)
|
||||
```
|
||||
Note that `bound` can be made more specific with setting
|
||||
`bound=[lower_bound, upper_bound]` as detailed below.
|
||||
|
||||
Args:
|
||||
x: Input `Tensor`.
|
||||
bound: A float constant or a float 2-tuple defining the region for the
|
||||
linear extrapolation of `acos`.
|
||||
If `bound` is a float scalar, linearly interpolates acos for
|
||||
`x <= -bound` or `bound <= x`.
|
||||
If `bound` is a 2-tuple, the first/second element of `bound`
|
||||
describes the lower/upper bound that defines the lower/upper
|
||||
extrapolation region, i.e. the region where
|
||||
`x <= bound[0]`/`bound[1] <= x`.
|
||||
Note that all elements of `bound` have to be within (-1, 1).
|
||||
Returns:
|
||||
acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`.
|
||||
"""
|
||||
|
||||
if isinstance(bound, float):
|
||||
upper_bound = bound
|
||||
lower_bound = -bound
|
||||
else:
|
||||
lower_bound, upper_bound = bound
|
||||
|
||||
if lower_bound > upper_bound:
|
||||
raise ValueError("lower bound has to be smaller or equal to upper bound.")
|
||||
|
||||
if lower_bound <= -1.0 or upper_bound >= 1.0:
|
||||
raise ValueError("Both lower bound and upper bound have to be within (-1, 1).")
|
||||
|
||||
# init an empty tensor and define the domain sets
|
||||
acos_extrap = torch.empty_like(x)
|
||||
x_upper = x >= upper_bound
|
||||
x_lower = x <= lower_bound
|
||||
x_mid = (~x_upper) & (~x_lower)
|
||||
|
||||
# acos calculation for upper_bound < x < lower_bound
|
||||
acos_extrap[x_mid] = torch.acos(x[x_mid])
|
||||
# the linear extrapolation for x >= upper_bound
|
||||
acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound)
|
||||
# the linear extrapolation for x <= lower_bound
|
||||
acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound)
|
||||
|
||||
return acos_extrap
|
||||
|
||||
|
||||
def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`.
|
||||
"""
|
||||
return (x - x0) * _dacos_dx(x0) + math.acos(x0)
|
||||
|
||||
|
||||
def _dacos_dx(x: float) -> float:
|
||||
"""
|
||||
Calculates the derivative of `arccos(x)` w.r.t. `x`.
|
||||
"""
|
||||
return (-1.0) / math.sqrt(1.0 - x * x)
|
||||
Reference in New Issue
Block a user