mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
SO3 improvements for stable gradients.
Summary: Improves so3 functions to make gradient computation stable: - Instead of `torch.acos`, uses `acos_linear_extrapolation` which has finite gradients of reasonable magnitude for all inputs. - Adds tests for the latter. The tests of the finiteness of the gradient in `test_so3_exp_singularity`, `test_so3_exp_singularity`, `test_so3_cos_bound` would fail if the `so3` functions would call `torch.acos` instead of `acos_linear_extrapolation`. Reviewed By: bottler Differential Revision: D23326429 fbshipit-source-id: dc296abf2ae3ddfb3942c8146621491a9cb740ee
This commit is contained in:
parent
dd45123f20
commit
9f14e82b5a
@ -22,6 +22,7 @@ from .rotation_conversions import (
|
||||
)
|
||||
from .so3 import (
|
||||
so3_exponential_map,
|
||||
so3_exp_map,
|
||||
so3_log_map,
|
||||
so3_relative_angle,
|
||||
so3_rotation_angle,
|
||||
|
@ -1,13 +1,20 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..transforms import acos_linear_extrapolation
|
||||
|
||||
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
||||
|
||||
|
||||
def so3_relative_angle(R1, R2, cos_angle: bool = False):
|
||||
def so3_relative_angle(
|
||||
R1: torch.Tensor,
|
||||
R2: torch.Tensor,
|
||||
cos_angle: bool = False,
|
||||
cos_bound: float = 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the relative angle (in radians) between pairs of
|
||||
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
|
||||
@ -20,8 +27,12 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
|
||||
R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
cos_angle: If==True return cosine of the relative angle rather than
|
||||
the angle itself. This can avoid the unstable
|
||||
calculation of `acos`.
|
||||
the angle itself. This can avoid the unstable calculation of `acos`.
|
||||
cos_bound: Clamps the cosine of the relative rotation angle to
|
||||
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
|
||||
of the `acos` call. Note that the non-finite outputs/gradients
|
||||
are returned when the angle is requested (i.e. `cos_angle==False`)
|
||||
and the rotation angle is close to 0 or π.
|
||||
|
||||
Returns:
|
||||
Corresponding rotation angles of shape `(minibatch,)`.
|
||||
@ -32,10 +43,15 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
|
||||
ValueError if `R1` or `R2` has an unexpected trace.
|
||||
"""
|
||||
R12 = torch.bmm(R1, R2.permute(0, 2, 1))
|
||||
return so3_rotation_angle(R12, cos_angle=cos_angle)
|
||||
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
|
||||
|
||||
|
||||
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
|
||||
def so3_rotation_angle(
|
||||
R: torch.Tensor,
|
||||
eps: float = 1e-4,
|
||||
cos_angle: bool = False,
|
||||
cos_bound: float = 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates angles (in radians) of a batch of rotation matrices `R` with
|
||||
`angle = acos(0.5 * (Trace(R)-1))`. The trace of the
|
||||
@ -47,8 +63,13 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
|
||||
R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
eps: Tolerance for the valid trace check.
|
||||
cos_angle: If==True return cosine of the rotation angles rather than
|
||||
the angle itself. This can avoid the unstable
|
||||
calculation of `acos`.
|
||||
the angle itself. This can avoid the unstable
|
||||
calculation of `acos`.
|
||||
cos_bound: Clamps the cosine of the rotation angle to
|
||||
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
|
||||
of the `acos` call. Note that the non-finite outputs/gradients
|
||||
are returned when the angle is requested (i.e. `cos_angle==False`)
|
||||
and the rotation angle is close to 0 or π.
|
||||
|
||||
Returns:
|
||||
Corresponding rotation angles of shape `(minibatch,)`.
|
||||
@ -68,20 +89,19 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
|
||||
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
|
||||
raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
|
||||
|
||||
# clamp to valid range
|
||||
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)
|
||||
|
||||
# phi ... rotation angle
|
||||
phi = 0.5 * (rot_trace - 1.0)
|
||||
phi_cos = (rot_trace - 1.0) * 0.5
|
||||
|
||||
if cos_angle:
|
||||
return phi
|
||||
return phi_cos
|
||||
else:
|
||||
# pyre-fixme[16]: `float` has no attribute `acos`.
|
||||
return phi.acos()
|
||||
if cos_bound > 0.0:
|
||||
return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
|
||||
else:
|
||||
return torch.acos(phi_cos)
|
||||
|
||||
|
||||
def so3_exponential_map(log_rot, eps: float = 0.0001):
|
||||
def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
|
||||
"""
|
||||
Convert a batch of logarithmic representations of rotation matrices `log_rot`
|
||||
to a batch of 3x3 rotation matrices using Rodrigues formula [1].
|
||||
@ -94,18 +114,31 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
|
||||
which is handled by clamping controlled with the `eps` argument.
|
||||
|
||||
Args:
|
||||
log_rot: Batch of vectors of shape `(minibatch , 3)`.
|
||||
log_rot: Batch of vectors of shape `(minibatch, 3)`.
|
||||
eps: A float constant handling the conversion singularity.
|
||||
|
||||
Returns:
|
||||
Batch of rotation matrices of shape `(minibatch , 3 , 3)`.
|
||||
Batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
|
||||
Raises:
|
||||
ValueError if `log_rot` is of incorrect shape.
|
||||
|
||||
[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
|
||||
"""
|
||||
return _so3_exp_map(log_rot, eps=eps)[0]
|
||||
|
||||
|
||||
so3_exponential_map = so3_exp_map
|
||||
|
||||
|
||||
def _so3_exp_map(
|
||||
log_rot: torch.Tensor, eps: float = 0.0001
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
A helper function that computes the so3 exponential map and,
|
||||
apart from the rotation matrix, also returns intermediate variables
|
||||
that can be re-used in other functions.
|
||||
"""
|
||||
_, dim = log_rot.shape
|
||||
if dim != 3:
|
||||
raise ValueError("Input tensor shape has to be Nx3.")
|
||||
@ -117,27 +150,35 @@ def so3_exponential_map(log_rot, eps: float = 0.0001):
|
||||
fac1 = rot_angles_inv * rot_angles.sin()
|
||||
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
|
||||
skews = hat(log_rot)
|
||||
skews_square = torch.bmm(skews, skews)
|
||||
|
||||
R = (
|
||||
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
|
||||
fac1[:, None, None] * skews
|
||||
+ fac2[:, None, None] * torch.bmm(skews, skews)
|
||||
+ fac2[:, None, None] * skews_square
|
||||
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
|
||||
)
|
||||
|
||||
return R
|
||||
return R, rot_angles, skews, skews_square
|
||||
|
||||
|
||||
def so3_log_map(R, eps: float = 0.0001):
|
||||
def so3_log_map(
|
||||
R: torch.Tensor, eps: float = 0.0001, cos_bound: float = 1e-4
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a batch of 3x3 rotation matrices `R`
|
||||
to a batch of 3-dimensional matrix logarithms of rotation matrices
|
||||
The conversion has a singularity around `(R=I)` which is handled
|
||||
by clamping controlled with the `eps` argument.
|
||||
by clamping controlled with the `eps` and `cos_bound` arguments.
|
||||
|
||||
Args:
|
||||
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
eps: A float constant handling the conversion singularity.
|
||||
cos_bound: Clamps the cosine of the rotation angle to
|
||||
[-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs/gradients
|
||||
of the `acos` call when computing `so3_rotation_angle`.
|
||||
Note that the non-finite outputs/gradients are returned when
|
||||
the rotation angle is close to 0 or π.
|
||||
|
||||
Returns:
|
||||
Batch of logarithms of input rotation matrices
|
||||
@ -152,22 +193,26 @@ def so3_log_map(R, eps: float = 0.0001):
|
||||
if dim1 != 3 or dim2 != 3:
|
||||
raise ValueError("Input has to be a batch of 3x3 Tensors.")
|
||||
|
||||
phi = so3_rotation_angle(R)
|
||||
phi = so3_rotation_angle(R, cos_bound=cos_bound, eps=eps)
|
||||
|
||||
phi_sin = phi.sin()
|
||||
phi_sin = torch.sin(phi)
|
||||
|
||||
phi_denom = (
|
||||
torch.clamp(phi_sin.abs(), eps) * phi_sin.sign()
|
||||
+ (phi_sin == 0).type_as(phi) * eps
|
||||
)
|
||||
# We want to avoid a tiny denominator of phi_factor = phi / (2.0 * phi_sin).
|
||||
# Hence, for phi_sin.abs() <= 0.5 * eps, we approximate phi_factor with
|
||||
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
|
||||
phi_factor = torch.empty_like(phi)
|
||||
ok_denom = phi_sin.abs() > (0.5 * eps)
|
||||
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
|
||||
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
|
||||
|
||||
log_rot_hat = phi_factor[:, None, None] * (R - R.permute(0, 2, 1))
|
||||
|
||||
log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1))
|
||||
log_rot = hat_inv(log_rot_hat)
|
||||
|
||||
return log_rot
|
||||
|
||||
|
||||
def hat_inv(h):
|
||||
def hat_inv(h: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute the inverse Hat operator [1] of a batch of 3x3 matrices.
|
||||
|
||||
@ -188,9 +233,9 @@ def hat_inv(h):
|
||||
if dim1 != 3 or dim2 != 3:
|
||||
raise ValueError("Input has to be a batch of 3x3 Tensors.")
|
||||
|
||||
ss_diff = (h + h.permute(0, 2, 1)).abs().max()
|
||||
ss_diff = torch.abs(h + h.permute(0, 2, 1)).max()
|
||||
if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
|
||||
raise ValueError("One of input matrices not skew-symmetric.")
|
||||
raise ValueError("One of input matrices is not skew-symmetric.")
|
||||
|
||||
x = h[:, 2, 1]
|
||||
y = h[:, 0, 2]
|
||||
@ -201,7 +246,7 @@ def hat_inv(h):
|
||||
return v
|
||||
|
||||
|
||||
def hat(v):
|
||||
def hat(v: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute the Hat operator [1] of a batch of 3D vectors.
|
||||
|
||||
@ -225,7 +270,7 @@ def hat(v):
|
||||
if dim != 3:
|
||||
raise ValueError("Input vectors have to be 3-dimensional.")
|
||||
|
||||
h = v.new_zeros(N, 3, 3)
|
||||
h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device)
|
||||
|
||||
x, y, z = v.unbind(1)
|
||||
|
||||
|
@ -9,9 +9,10 @@ import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.transforms.so3 import (
|
||||
hat,
|
||||
so3_exponential_map,
|
||||
so3_exp_map,
|
||||
so3_log_map,
|
||||
so3_relative_angle,
|
||||
so3_rotation_angle,
|
||||
)
|
||||
|
||||
|
||||
@ -53,10 +54,10 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
def test_determinant(self):
|
||||
"""
|
||||
Tests whether the determinants of 3x3 rotation matrices produced
|
||||
by `so3_exponential_map` are (almost) equal to 1.
|
||||
by `so3_exp_map` are (almost) equal to 1.
|
||||
"""
|
||||
log_rot = TestSO3.init_log_rot(batch_size=30)
|
||||
Rs = so3_exponential_map(log_rot)
|
||||
Rs = so3_exp_map(log_rot)
|
||||
dets = torch.det(Rs)
|
||||
self.assertClose(dets, torch.ones_like(dets), atol=1e-4)
|
||||
|
||||
@ -75,14 +76,14 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_bad_so3_input_value_err(self):
|
||||
"""
|
||||
Tests whether `so3_exponential_map` and `so3_log_map` correctly return
|
||||
Tests whether `so3_exp_map` and `so3_log_map` correctly return
|
||||
a ValueError if called with an argument of incorrect shape or, in case
|
||||
of `so3_exponential_map`, unexpected trace.
|
||||
of `so3_exp_map`, unexpected trace.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
log_rot = torch.randn(size=[5, 4], device=device)
|
||||
with self.assertRaises(ValueError) as err:
|
||||
so3_exponential_map(log_rot)
|
||||
so3_exp_map(log_rot)
|
||||
self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception))
|
||||
|
||||
rot = torch.randn(size=[5, 3, 5], device=device)
|
||||
@ -106,17 +107,22 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_so3_exp_singularity(self, batch_size: int = 100):
|
||||
"""
|
||||
Tests whether the `so3_exponential_map` is robust to the input vectors
|
||||
Tests whether the `so3_exp_map` is robust to the input vectors
|
||||
the norms of which are close to the numerically unstable region
|
||||
(vectors with low l2-norms).
|
||||
"""
|
||||
# generate random log-rotations with a tiny angle
|
||||
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
|
||||
log_rot_small = log_rot * 1e-6
|
||||
R = so3_exponential_map(log_rot_small)
|
||||
log_rot_small.requires_grad = True
|
||||
R = so3_exp_map(log_rot_small)
|
||||
# tests whether all outputs are finite
|
||||
R_sum = float(R.sum())
|
||||
self.assertEqual(R_sum, R_sum)
|
||||
self.assertTrue(torch.isfinite(R).all())
|
||||
# tests whether the gradient is not None and all finite
|
||||
loss = R.sum()
|
||||
loss.backward()
|
||||
self.assertIsNotNone(log_rot_small.grad)
|
||||
self.assertTrue(torch.isfinite(log_rot_small.grad).all())
|
||||
|
||||
def test_so3_log_singularity(self, batch_size: int = 100):
|
||||
"""
|
||||
@ -129,6 +135,107 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
identity = torch.eye(3, device=device)
|
||||
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
|
||||
r = [identity, rot180]
|
||||
# add random rotations and random almost orthonormal matrices
|
||||
r.extend(
|
||||
[
|
||||
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
|
||||
# this adds random noise to the second half
|
||||
# of the random orthogonal matrices to generate
|
||||
# near-orthogonal matrices
|
||||
for i in range(batch_size - 2)
|
||||
]
|
||||
)
|
||||
r = torch.stack(r)
|
||||
r.requires_grad = True
|
||||
# the log of the rotation matrix r
|
||||
r_log = so3_log_map(r, cos_bound=1e-4, eps=1e-2)
|
||||
# tests whether all outputs are finite
|
||||
self.assertTrue(torch.isfinite(r_log).all())
|
||||
# tests whether the gradient is not None and all finite
|
||||
loss = r.sum()
|
||||
loss.backward()
|
||||
self.assertIsNotNone(r.grad)
|
||||
self.assertTrue(torch.isfinite(r.grad).all())
|
||||
|
||||
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that
|
||||
`so3_exp_map(so3_log_map(so3_exp_map(log_rot)))
|
||||
== so3_exp_map(log_rot)`
|
||||
for a randomly generated batch of rotation matrix logarithms `log_rot`.
|
||||
Unlike `test_so3_log_to_exp_to_log`, this test checks the
|
||||
correctness of converting a `log_rot` which contains values > math.pi.
|
||||
"""
|
||||
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
|
||||
# check also the singular cases where rot. angle = {0, 2pi}
|
||||
log_rot[:2] = 0
|
||||
log_rot[1, 0] = 2.0 * math.pi - 1e-6
|
||||
rot = so3_exp_map(log_rot, eps=1e-4)
|
||||
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-4, cos_bound=1e-6), eps=1e-6)
|
||||
self.assertClose(rot, rot_, atol=0.01)
|
||||
angles = so3_relative_angle(rot, rot_, cos_bound=1e-6)
|
||||
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
|
||||
|
||||
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that `so3_log_map(so3_exp_map(log_rot))==log_rot` for
|
||||
a randomly generated batch of rotation matrix logarithms `log_rot`.
|
||||
"""
|
||||
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
|
||||
# check also the singular cases where rot. angle = 0
|
||||
log_rot[:1] = 0
|
||||
log_rot_ = so3_log_map(so3_exp_map(log_rot))
|
||||
self.assertClose(log_rot, log_rot_, atol=1e-4)
|
||||
|
||||
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that `so3_exp_map(so3_log_map(R))==R` for
|
||||
a batch of randomly generated rotation matrices `R`.
|
||||
"""
|
||||
rot = TestSO3.init_rot(batch_size=batch_size)
|
||||
non_singular = (so3_rotation_angle(rot) - math.pi).abs() > 1e-2
|
||||
rot = rot[non_singular]
|
||||
rot_ = so3_exp_map(so3_log_map(rot, eps=1e-8, cos_bound=1e-8), eps=1e-8)
|
||||
self.assertClose(rot_, rot, atol=0.1)
|
||||
angles = so3_relative_angle(rot, rot_, cos_bound=1e-4)
|
||||
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
|
||||
|
||||
def test_so3_cos_relative_angle(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
|
||||
is the same as `so3_relative_angle(R1, R2, cos_angle=True)` for
|
||||
batches of randomly generated rotation matrices `R1` and `R2`.
|
||||
"""
|
||||
rot1 = TestSO3.init_rot(batch_size=batch_size)
|
||||
rot2 = TestSO3.init_rot(batch_size=batch_size)
|
||||
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
|
||||
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
|
||||
self.assertClose(angles, angles_, atol=1e-4)
|
||||
|
||||
def test_so3_cos_angle(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that `so3_rotation_angle(R, cos_angle=False).cos()`
|
||||
is the same as `so3_rotation_angle(R, cos_angle=True)` for
|
||||
a batch of randomly generated rotation matrices `R`.
|
||||
"""
|
||||
rot = TestSO3.init_rot(batch_size=batch_size)
|
||||
angles = so3_rotation_angle(rot, cos_angle=False).cos()
|
||||
angles_ = so3_rotation_angle(rot, cos_angle=True)
|
||||
self.assertClose(angles, angles_, atol=1e-4)
|
||||
|
||||
def test_so3_cos_bound(self, batch_size: int = 100):
|
||||
"""
|
||||
Checks that for an identity rotation `R=I`, the so3_rotation_angle returns
|
||||
non-finite gradients when `cos_bound=None` and finite gradients
|
||||
for `cos_bound > 0.0`.
|
||||
"""
|
||||
# generate random rotations with a tiny angle to generate cases
|
||||
# with the gradient singularity
|
||||
device = torch.device("cuda:0")
|
||||
identity = torch.eye(3, device=device)
|
||||
rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
|
||||
r = [identity, rot180]
|
||||
r.extend(
|
||||
[
|
||||
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
@ -136,65 +243,25 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
]
|
||||
)
|
||||
r = torch.stack(r)
|
||||
# the log of the rotation matrix r
|
||||
r_log = so3_log_map(r)
|
||||
# tests whether all outputs are finite
|
||||
r_sum = float(r_log.sum())
|
||||
self.assertEqual(r_sum, r_sum)
|
||||
|
||||
def test_so3_log_to_exp_to_log_to_exp(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that
|
||||
`so3_exponential_map(so3_log_map(so3_exponential_map(log_rot)))
|
||||
== so3_exponential_map(log_rot)`
|
||||
for a randomly generated batch of rotation matrix logarithms `log_rot`.
|
||||
Unlike `test_so3_log_to_exp_to_log`, this test checks the
|
||||
correctness of converting a `log_rot` which contains values > math.pi.
|
||||
"""
|
||||
log_rot = 2.0 * TestSO3.init_log_rot(batch_size=batch_size)
|
||||
# check also the singular cases where rot. angle = {0, pi, 2pi, 3pi}
|
||||
log_rot[:3] = 0
|
||||
log_rot[1, 0] = math.pi
|
||||
log_rot[2, 0] = 2.0 * math.pi
|
||||
log_rot[3, 0] = 3.0 * math.pi
|
||||
rot = so3_exponential_map(log_rot, eps=1e-8)
|
||||
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
|
||||
angles = so3_relative_angle(rot, rot_)
|
||||
self.assertClose(angles, torch.zeros_like(angles), atol=0.01)
|
||||
|
||||
def test_so3_log_to_exp_to_log(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that `so3_log_map(so3_exponential_map(log_rot))==log_rot` for
|
||||
a randomly generated batch of rotation matrix logarithms `log_rot`.
|
||||
"""
|
||||
log_rot = TestSO3.init_log_rot(batch_size=batch_size)
|
||||
# check also the singular cases where rot. angle = 0
|
||||
log_rot[:1] = 0
|
||||
log_rot_ = so3_log_map(so3_exponential_map(log_rot))
|
||||
self.assertClose(log_rot, log_rot_, atol=1e-4)
|
||||
|
||||
def test_so3_exp_to_log_to_exp(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that `so3_exponential_map(so3_log_map(R))==R` for
|
||||
a batch of randomly generated rotation matrices `R`.
|
||||
"""
|
||||
rot = TestSO3.init_rot(batch_size=batch_size)
|
||||
rot_ = so3_exponential_map(so3_log_map(rot, eps=1e-8), eps=1e-8)
|
||||
angles = so3_relative_angle(rot, rot_)
|
||||
# TODO: a lot of precision lost here ...
|
||||
self.assertClose(angles, torch.zeros_like(angles), atol=0.1)
|
||||
|
||||
def test_so3_cos_angle(self, batch_size: int = 100):
|
||||
"""
|
||||
Check that `so3_relative_angle(R1, R2, cos_angle=False).cos()`
|
||||
is the same as `so3_relative_angle(R1, R2, cos_angle=True)`
|
||||
batches of randomly generated rotation matrices `R1` and `R2`.
|
||||
"""
|
||||
rot1 = TestSO3.init_rot(batch_size=batch_size)
|
||||
rot2 = TestSO3.init_rot(batch_size=batch_size)
|
||||
angles = so3_relative_angle(rot1, rot2, cos_angle=False).cos()
|
||||
angles_ = so3_relative_angle(rot1, rot2, cos_angle=True)
|
||||
self.assertClose(angles, angles_)
|
||||
r.requires_grad = True
|
||||
for is_grad_finite in (True, False):
|
||||
# clear the gradients and decide the cos_bound:
|
||||
# for is_grad_finite we run so3_rotation_angle with cos_bound
|
||||
# set to a small float, otherwise we set to 0.0
|
||||
r.grad = None
|
||||
cos_bound = 1e-4 if is_grad_finite else 0.0
|
||||
# compute the angles of r
|
||||
angles = so3_rotation_angle(r, cos_bound=cos_bound)
|
||||
# tests whether all outputs are finite in both cases
|
||||
self.assertTrue(torch.isfinite(angles).all())
|
||||
# compute the gradients
|
||||
loss = angles.sum()
|
||||
loss.backward()
|
||||
# tests whether the gradient is not None for both cases
|
||||
self.assertIsNotNone(r.grad)
|
||||
if is_grad_finite:
|
||||
# all grad values have to be finite
|
||||
self.assertTrue(torch.isfinite(r.grad).all())
|
||||
|
||||
@staticmethod
|
||||
def so3_expmap(batch_size: int = 10):
|
||||
@ -202,7 +269,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def compute_rots():
|
||||
so3_exponential_map(log_rot)
|
||||
so3_exp_map(log_rot)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return compute_rots
|
||||
|
Loading…
x
Reference in New Issue
Block a user