mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Update so3 operations for numerical stability
Summary: Replace implementations of `so3_exp_map` and `so3_log_map` in so3.py with existing more-stable implementations. Reviewed By: bottler Differential Revision: D52513319 fbshipit-source-id: fbfc039643fef284d8baa11bab61651964077afe
This commit is contained in:
parent
3621a36494
commit
292acc71a3
@ -8,6 +8,7 @@ import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
|
||||
from ..transforms import acos_linear_extrapolation
|
||||
|
||||
@ -160,19 +161,10 @@ def _so3_exp_map(
|
||||
nrms = (log_rot * log_rot).sum(1)
|
||||
# phis ... rotation angles
|
||||
rot_angles = torch.clamp(nrms, eps).sqrt()
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
rot_angles_inv = 1.0 / rot_angles
|
||||
fac1 = rot_angles_inv * rot_angles.sin()
|
||||
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
|
||||
skews = hat(log_rot)
|
||||
skews_square = torch.bmm(skews, skews)
|
||||
|
||||
R = (
|
||||
fac1[:, None, None] * skews
|
||||
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
|
||||
+ fac2[:, None, None] * skews_square
|
||||
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
|
||||
)
|
||||
R = rotation_conversions.axis_angle_to_matrix(log_rot)
|
||||
|
||||
return R, rot_angles, skews, skews_square
|
||||
|
||||
@ -183,49 +175,23 @@ def so3_log_map(
|
||||
"""
|
||||
Convert a batch of 3x3 rotation matrices `R`
|
||||
to a batch of 3-dimensional matrix logarithms of rotation matrices
|
||||
The conversion has a singularity around `(R=I)` which is handled
|
||||
by clamping controlled with the `eps` and `cos_bound` arguments.
|
||||
The conversion has a singularity around `(R=I)`.
|
||||
|
||||
Args:
|
||||
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
eps: A float constant handling the conversion singularity.
|
||||
cos_bound: Clamps the cosine of the rotation angle to
|
||||
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
|
||||
of the `acos` call when computing `so3_rotation_angle`.
|
||||
Note that the non-finite outputs/gradients are returned when
|
||||
the rotation angle is close to 0 or π.
|
||||
eps: (unused, for backward compatibility)
|
||||
cos_bound: (unused, for backward compatibility)
|
||||
|
||||
Returns:
|
||||
Batch of logarithms of input rotation matrices
|
||||
of shape `(minibatch, 3)`.
|
||||
|
||||
Raises:
|
||||
ValueError if `R` is of incorrect shape.
|
||||
ValueError if `R` has an unexpected trace.
|
||||
"""
|
||||
|
||||
N, dim1, dim2 = R.shape
|
||||
if dim1 != 3 or dim2 != 3:
|
||||
raise ValueError("Input has to be a batch of 3x3 Tensors.")
|
||||
|
||||
phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)
|
||||
|
||||
phi_sin = torch.sin(phi)
|
||||
|
||||
# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
|
||||
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
|
||||
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
|
||||
phi_factor = torch.empty_like(phi)
|
||||
ok_denom = phi_sin.abs() > (0.5 * eps)
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
|
||||
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
|
||||
|
||||
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
|
||||
|
||||
log_rot = hat_inv(log_rot_hat)
|
||||
|
||||
return log_rot
|
||||
return rotation_conversions.matrix_to_axis_angle(R)
|
||||
|
||||
|
||||
def hat_inv(h: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -97,20 +97,6 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
so3_log_map(rot)
|
||||
self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception))
|
||||
|
||||
# trace of rot definitely bigger than 3 or smaller than -1
|
||||
rot = torch.cat(
|
||||
(
|
||||
torch.rand(size=[5, 3, 3], device=device) + 4.0,
|
||||
torch.rand(size=[5, 3, 3], device=device) - 3.0,
|
||||
)
|
||||
)
|
||||
with self.assertRaises(ValueError) as err:
|
||||
so3_log_map(rot)
|
||||
self.assertTrue(
|
||||
"A matrix has trace outside valid range [-1-eps,3+eps]."
|
||||
in str(err.exception)
|
||||
)
|
||||
|
||||
def test_so3_exp_singularity(self, batch_size: int = 100):
|
||||
"""
|
||||
Tests whether the `so3_exp_map` is robust to the input vectors
|
||||
|
Loading…
x
Reference in New Issue
Block a user