mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Sign issue about quaternion_to_matrix and matrix_to_quaternion
Summary: As reported on github, `matrix_to_quaternion` was incorrect for rotations by 180˚. We resolved the sign of the component `i` based on the sign of `i*r`, assuming `r > 0`, which is untrue if `r == 0`. This diff handles special cases and ensures we use the non-zero elements to copy the sign from. Reviewed By: bottler Differential Revision: D29149465 fbshipit-source-id: cd508cc31567fc37ea3463dd7e8c8e8d5d64a235
This commit is contained in:
parent
a8610e9da4
commit
1b39cebe92
@ -82,7 +82,7 @@ def _copysign(a, b):
|
|||||||
return torch.where(signs_differ, -a, a)
|
return torch.where(signs_differ, -a, a)
|
||||||
|
|
||||||
|
|
||||||
def _sqrt_positive_part(x):
|
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Returns torch.sqrt(torch.max(0, x))
|
Returns torch.sqrt(torch.max(0, x))
|
||||||
but with a zero subgradient where x is 0.
|
but with a zero subgradient where x is 0.
|
||||||
@ -93,7 +93,7 @@ def _sqrt_positive_part(x):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def matrix_to_quaternion(matrix):
|
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert rotations given as rotation matrices to quaternions.
|
Convert rotations given as rotation matrices to quaternions.
|
||||||
|
|
||||||
@ -105,17 +105,44 @@ def matrix_to_quaternion(matrix):
|
|||||||
"""
|
"""
|
||||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||||
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
||||||
m00 = matrix[..., 0, 0]
|
|
||||||
m11 = matrix[..., 1, 1]
|
batch_dim = matrix.shape[:-2]
|
||||||
m22 = matrix[..., 2, 2]
|
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||||
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
matrix.reshape(*batch_dim, 9), dim=-1
|
||||||
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
)
|
||||||
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
|
||||||
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
q_abs = _sqrt_positive_part(
|
||||||
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
torch.stack(
|
||||||
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
[
|
||||||
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
1.0 + m00 + m11 + m22,
|
||||||
return torch.stack((o0, o1, o2, o3), -1)
|
1.0 + m00 - m11 - m22,
|
||||||
|
1.0 - m00 + m11 - m22,
|
||||||
|
1.0 - m00 - m11 + m22,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# we produce the desired quaternion multiplied by each of r, i, j, k
|
||||||
|
quat_by_rijk = torch.stack(
|
||||||
|
[
|
||||||
|
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||||
|
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||||
|
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||||
|
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||||
|
],
|
||||||
|
dim=-2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# clipping is not important here; if q_abs is small, the candidate won't be picked
|
||||||
|
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clip(0.1))
|
||||||
|
|
||||||
|
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
||||||
|
# forall i; we pick the best-conditioned one (with the largest denominator)
|
||||||
|
|
||||||
|
return quat_candidates[
|
||||||
|
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
|
||||||
|
].reshape(*batch_dim, 4)
|
||||||
|
|
||||||
|
|
||||||
def _axis_angle_rotation(axis: str, angle):
|
def _axis_angle_rotation(axis: str, angle):
|
||||||
|
@ -4,7 +4,9 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.transforms.rotation_conversions import (
|
from pytorch3d.transforms.rotation_conversions import (
|
||||||
@ -64,7 +66,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
"""quat -> mtx -> quat"""
|
"""quat -> mtx -> quat"""
|
||||||
data = random_quaternions(13, dtype=torch.float64)
|
data = random_quaternions(13, dtype=torch.float64)
|
||||||
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
|
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
|
||||||
self.assertClose(data, mdata)
|
self._assert_quaternions_close(data, mdata)
|
||||||
|
|
||||||
def test_to_quat(self):
|
def test_to_quat(self):
|
||||||
"""mtx -> quat -> mtx"""
|
"""mtx -> quat -> mtx"""
|
||||||
@ -146,8 +148,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
b_matrix = quaternion_to_matrix(b)
|
b_matrix = quaternion_to_matrix(b)
|
||||||
ab_matrix = torch.matmul(a_matrix, b_matrix)
|
ab_matrix = torch.matmul(a_matrix, b_matrix)
|
||||||
ab_from_matrix = matrix_to_quaternion(ab_matrix)
|
ab_from_matrix = matrix_to_quaternion(ab_matrix)
|
||||||
self.assertEqual(ab.shape, ab_from_matrix.shape)
|
self._assert_quaternions_close(ab, ab_from_matrix)
|
||||||
self.assertClose(ab, ab_from_matrix)
|
|
||||||
|
|
||||||
def test_matrix_to_quaternion_corner_case(self):
|
def test_matrix_to_quaternion_corner_case(self):
|
||||||
"""Check no bad gradients from sqrt(0)."""
|
"""Check no bad gradients from sqrt(0)."""
|
||||||
@ -161,7 +162,34 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
self.assertClose(matrix, 0.95 * torch.eye(3))
|
self.assertClose(matrix, matrix, msg="Result has non-finite values")
|
||||||
|
delta = 1e-2
|
||||||
|
self.assertLess(
|
||||||
|
matrix.trace(),
|
||||||
|
3.0 - delta,
|
||||||
|
msg="Identity initialisation unchanged by a gradient step",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_matrix_to_quaternion_by_pi(self):
|
||||||
|
# We check that rotations by pi around each of the 26
|
||||||
|
# nonzero vectors containing nothing but 0, 1 and -1
|
||||||
|
# are mapped to the right quaternions.
|
||||||
|
# This is representative across the directions.
|
||||||
|
options = [0.0, -1.0, 1.0]
|
||||||
|
axes = [
|
||||||
|
torch.tensor(vec)
|
||||||
|
for vec in itertools.islice( # exclude [0, 0, 0]
|
||||||
|
itertools.product(options, options, options), 1, None
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
axes = torch.nn.functional.normalize(torch.stack(axes), dim=-1)
|
||||||
|
# Rotation by pi around unit vector x is given by
|
||||||
|
# the matrix 2 x x^T - Id.
|
||||||
|
R = 2 * torch.matmul(axes[..., None], axes[..., None, :]) - torch.eye(3)
|
||||||
|
quats_hat = matrix_to_quaternion(R)
|
||||||
|
R_hat = quaternion_to_matrix(quats_hat)
|
||||||
|
self.assertClose(R, R_hat, atol=1e-3)
|
||||||
|
|
||||||
def test_from_axis_angle(self):
|
def test_from_axis_angle(self):
|
||||||
"""axis_angle -> mtx -> axis_angle"""
|
"""axis_angle -> mtx -> axis_angle"""
|
||||||
@ -228,3 +256,20 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertClose(
|
self.assertClose(
|
||||||
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
|
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _assert_quaternions_close(
|
||||||
|
self,
|
||||||
|
input: Union[torch.Tensor, np.ndarray],
|
||||||
|
other: Union[torch.Tensor, np.ndarray],
|
||||||
|
*,
|
||||||
|
rtol: float = 1e-05,
|
||||||
|
atol: float = 1e-08,
|
||||||
|
equal_nan: bool = False,
|
||||||
|
msg: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.assertEqual(np.shape(input), np.shape(other))
|
||||||
|
dot = (input * other).sum(-1)
|
||||||
|
ones = torch.ones_like(dot)
|
||||||
|
self.assertClose(
|
||||||
|
dot.abs(), ones, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user