mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Update so3 operations for numerical stability
Summary: Replace implementations of `so3_exp_map` and `so3_log_map` in so3.py with existing more-stable implementations. Reviewed By: bottler Differential Revision: D52513319 fbshipit-source-id: fbfc039643fef284d8baa11bab61651964077afe
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3621a36494
commit
292acc71a3
@@ -8,6 +8,7 @@ import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
|
||||
from ..transforms import acos_linear_extrapolation
|
||||
|
||||
@@ -160,19 +161,10 @@ def _so3_exp_map(
|
||||
nrms = (log_rot * log_rot).sum(1)
|
||||
# phis ... rotation angles
|
||||
rot_angles = torch.clamp(nrms, eps).sqrt()
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
rot_angles_inv = 1.0 / rot_angles
|
||||
fac1 = rot_angles_inv * rot_angles.sin()
|
||||
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
|
||||
skews = hat(log_rot)
|
||||
skews_square = torch.bmm(skews, skews)
|
||||
|
||||
R = (
|
||||
fac1[:, None, None] * skews
|
||||
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
|
||||
+ fac2[:, None, None] * skews_square
|
||||
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
|
||||
)
|
||||
R = rotation_conversions.axis_angle_to_matrix(log_rot)
|
||||
|
||||
return R, rot_angles, skews, skews_square
|
||||
|
||||
@@ -183,49 +175,23 @@ def so3_log_map(
|
||||
"""
|
||||
Convert a batch of 3x3 rotation matrices `R`
|
||||
to a batch of 3-dimensional matrix logarithms of rotation matrices
|
||||
The conversion has a singularity around `(R=I)` which is handled
|
||||
by clamping controlled with the `eps` and `cos_bound` arguments.
|
||||
The conversion has a singularity around `(R=I)`.
|
||||
|
||||
Args:
|
||||
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
eps: A float constant handling the conversion singularity.
|
||||
cos_bound: Clamps the cosine of the rotation angle to
|
||||
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
|
||||
of the `acos` call when computing `so3_rotation_angle`.
|
||||
Note that the non-finite outputs/gradients are returned when
|
||||
the rotation angle is close to 0 or π.
|
||||
eps: (unused, for backward compatibility)
|
||||
cos_bound: (unused, for backward compatibility)
|
||||
|
||||
Returns:
|
||||
Batch of logarithms of input rotation matrices
|
||||
of shape `(minibatch, 3)`.
|
||||
|
||||
Raises:
|
||||
ValueError if `R` is of incorrect shape.
|
||||
ValueError if `R` has an unexpected trace.
|
||||
"""
|
||||
|
||||
N, dim1, dim2 = R.shape
|
||||
if dim1 != 3 or dim2 != 3:
|
||||
raise ValueError("Input has to be a batch of 3x3 Tensors.")
|
||||
|
||||
phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)
|
||||
|
||||
phi_sin = torch.sin(phi)
|
||||
|
||||
# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
|
||||
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
|
||||
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
|
||||
phi_factor = torch.empty_like(phi)
|
||||
ok_denom = phi_sin.abs() > (0.5 * eps)
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
|
||||
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
|
||||
|
||||
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
|
||||
|
||||
log_rot = hat_inv(log_rot_hat)
|
||||
|
||||
return log_rot
|
||||
return rotation_conversions.matrix_to_axis_angle(R)
|
||||
|
||||
|
||||
def hat_inv(h: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user