mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00

Reviewed By: shannonzhu Differential Revision: D33970393 fbshipit-source-id: 9b4dfaccfc3793fd37705a923d689cb14c9d26ba
87 lines
2.9 KiB
Python
87 lines
2.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
|
|
|
|
# Default magnitude of the symmetric extrapolation bounds used by
# `acos_linear_extrapolation`. It sits slightly inside 1 because the bounds
# must lie strictly within (-1, 1): the slope of `arccos` is evaluated at the
# bounds and is undefined at +/-1.
DEFAULT_ACOS_BOUND: float = 1.0 - 1e-4
|
|
|
|
|
|
def acos_linear_extrapolation(
    x: torch.Tensor,
    bounds: Tuple[float, float] = (-DEFAULT_ACOS_BOUND, DEFAULT_ACOS_BOUND),
) -> torch.Tensor:
    """
    Implements `arccos(x)` which is linearly extrapolated outside of the
    interval given by `bounds`. This allows for stable backpropagation in case
    `x` is not guaranteed to be strictly within `(-1, 1)`.

    More specifically:
    ```
    bounds=(lower_bound, upper_bound)
    if lower_bound < x < upper_bound:
        acos_linear_extrapolation(x) = acos(x)
    elif x <= lower_bound: # 1st order Taylor approximation
        acos_linear_extrapolation(x)
            = acos(lower_bound) + dacos/dx(lower_bound) * (x - lower_bound)
    else:  # x >= upper_bound
        acos_linear_extrapolation(x)
            = acos(upper_bound) + dacos/dx(upper_bound) * (x - upper_bound)
    ```

    Args:
        x: Input `Tensor`.
        bounds: A float 2-tuple defining the region for the
            linear extrapolation of `acos`.
            The first/second element of `bounds`
            describes the lower/upper bound that defines the lower/upper
            extrapolation region, i.e. the region where
            `x <= bounds[0]`/`bounds[1] <= x`.
            Note that all elements of `bounds` have to be within (-1, 1).

    Returns:
        acos_linear_extrapolation: `Tensor` containing the extrapolated
            `arccos(x)`, allocated with `torch.empty_like(x)`.

    Raises:
        ValueError: If `bounds[0] > bounds[1]`, or if either bound lies
            outside of the open interval (-1, 1).
    """

    lower_bound, upper_bound = bounds

    if lower_bound > upper_bound:
        raise ValueError("lower bound has to be smaller or equal to upper bound.")

    if lower_bound <= -1.0 or upper_bound >= 1.0:
        raise ValueError("Both lower bound and upper bound have to be within (-1, 1).")

    # init an empty tensor and define the domain sets
    acos_extrap = torch.empty_like(x)
    x_upper = x >= upper_bound
    x_lower = x <= lower_bound
    x_mid = (~x_upper) & (~x_lower)

    # Exact acos for lower_bound < x < upper_bound. `torch.acos` is only
    # evaluated on the entries selected by `x_mid`, never on the entries that
    # fall into one of the extrapolation regions.
    acos_extrap[x_mid] = torch.acos(x[x_mid])
    # the linear extrapolation for x >= upper_bound
    acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound)
    # the linear extrapolation for x <= lower_bound
    acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound)

    return acos_extrap
|
|
|
|
|
|
def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor:
    """
    Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`.

    Args:
        x: `Tensor` of points at which the expansion is evaluated.
        x0: The expansion point, a Python float. Both `acos(x0)` and the
            slope at `x0` are computed with the `math` module, so they enter
            the result as plain scalar constants.

    Returns:
        `Tensor` holding `acos(x0) + dacos/dx(x0) * (x - x0)` for every
        element of `x`.
    """
    return (x - x0) * _dacos_dx(x0) + math.acos(x0)
|
|
|
|
|
|
def _dacos_dx(x: float) -> float:
    """
    Calculates the derivative of `arccos(x)` w.r.t. `x`.

    Args:
        x: Python float at which the derivative is evaluated.

    Returns:
        The value `-1 / sqrt(1 - x^2)`.

    Raises:
        ZeroDivisionError: If `1 - x * x` evaluates to zero (e.g. `x = +/-1`).
        ValueError: If `1 - x * x` is negative, raised by `math.sqrt`.
    """
    return (-1.0) / math.sqrt(1.0 - x * x)
|