Update so3 operations for numerical stability

Summary: Replace implementations of `so3_exp_map` and `so3_log_map` in so3.py with existing more-stable implementations.

Reviewed By: bottler

Differential Revision: D52513319

fbshipit-source-id: fbfc039643fef284d8baa11bab61651964077afe
This commit is contained in:
Abdelrahman Selim 2024-01-04 02:26:56 -08:00 committed by Facebook GitHub Bot
parent 3621a36494
commit 292acc71a3
2 changed files with 6 additions and 54 deletions

View File

@ -8,6 +8,7 @@ import warnings
from typing import Tuple from typing import Tuple
import torch import torch
from pytorch3d.transforms import rotation_conversions
from ..transforms import acos_linear_extrapolation from ..transforms import acos_linear_extrapolation
@ -160,19 +161,10 @@ def _so3_exp_map(
nrms = (log_rot * log_rot).sum(1) nrms = (log_rot * log_rot).sum(1)
# phis ... rotation angles # phis ... rotation angles
rot_angles = torch.clamp(nrms, eps).sqrt() rot_angles = torch.clamp(nrms, eps).sqrt()
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
rot_angles_inv = 1.0 / rot_angles
fac1 = rot_angles_inv * rot_angles.sin()
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
skews = hat(log_rot) skews = hat(log_rot)
skews_square = torch.bmm(skews, skews) skews_square = torch.bmm(skews, skews)
R = ( R = rotation_conversions.axis_angle_to_matrix(log_rot)
fac1[:, None, None] * skews
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
+ fac2[:, None, None] * skews_square
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
)
return R, rot_angles, skews, skews_square return R, rot_angles, skews, skews_square
@ -183,49 +175,23 @@ def so3_log_map(
""" """
Convert a batch of 3x3 rotation matrices `R` Convert a batch of 3x3 rotation matrices `R`
to a batch of 3-dimensional matrix logarithms of rotation matrices to a batch of 3-dimensional matrix logarithms of rotation matrices
The conversion has a singularity around `(R=I)` which is handled The conversion has a singularity around `(R=I)`.
by clamping controlled with the `eps` and `cos_bound` arguments.
Args: Args:
R: batch of rotation matrices of shape `(minibatch, 3, 3)`. R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: A float constant handling the conversion singularity. eps: (unused, for backward compatibility)
cos_bound: Clamps the cosine of the rotation angle to cos_bound: (unused, for backward compatibility)
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
of the `acos` call when computing `so3_rotation_angle`.
Note that the non-finite outputs/gradients are returned when
the rotation angle is close to 0 or π.
Returns: Returns:
Batch of logarithms of input rotation matrices Batch of logarithms of input rotation matrices
of shape `(minibatch, 3)`. of shape `(minibatch, 3)`.
Raises:
ValueError if `R` is of incorrect shape.
ValueError if `R` has an unexpected trace.
""" """
N, dim1, dim2 = R.shape N, dim1, dim2 = R.shape
if dim1 != 3 or dim2 != 3: if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.") raise ValueError("Input has to be a batch of 3x3 Tensors.")
phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps) return rotation_conversions.matrix_to_axis_angle(R)
phi_sin = torch.sin(phi)
# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
phi_factor = torch.empty_like(phi)
ok_denom = phi_sin.abs() > (0.5 * eps)
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
log_rot = hat_inv(log_rot_hat)
return log_rot
def hat_inv(h: torch.Tensor) -> torch.Tensor: def hat_inv(h: torch.Tensor) -> torch.Tensor:

View File

@ -97,20 +97,6 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
so3_log_map(rot) so3_log_map(rot)
self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception)) self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception))
# trace of rot definitely bigger than 3 or smaller than -1
rot = torch.cat(
(
torch.rand(size=[5, 3, 3], device=device) + 4.0,
torch.rand(size=[5, 3, 3], device=device) - 3.0,
)
)
with self.assertRaises(ValueError) as err:
so3_log_map(rot)
self.assertTrue(
"A matrix has trace outside valid range [-1-eps,3+eps]."
in str(err.exception)
)
def test_so3_exp_singularity(self, batch_size: int = 100): def test_so3_exp_singularity(self, batch_size: int = 100):
""" """
Tests whether the `so3_exp_map` is robust to the input vectors Tests whether the `so3_exp_map` is robust to the input vectors