mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Sign issue about quaternion_to_matrix and matrix_to_quaternion
Summary: As reported on github, `matrix_to_quaternion` was incorrect for rotations by 180˚. We resolved the sign of the component `i` based on the sign of `i*r`, assuming `r > 0`, which is untrue if `r == 0`. This diff handles special cases and ensures we use the non-zero elements to copy the sign from. Reviewed By: bottler Differential Revision: D29149465 fbshipit-source-id: cd508cc31567fc37ea3463dd7e8c8e8d5d64a235
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a8610e9da4
commit
1b39cebe92
@@ -4,7 +4,9 @@
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.transforms.rotation_conversions import (
|
||||
@@ -64,7 +66,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
"""quat -> mtx -> quat"""
|
||||
data = random_quaternions(13, dtype=torch.float64)
|
||||
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
|
||||
self.assertClose(data, mdata)
|
||||
self._assert_quaternions_close(data, mdata)
|
||||
|
||||
def test_to_quat(self):
|
||||
"""mtx -> quat -> mtx"""
|
||||
@@ -146,8 +148,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
b_matrix = quaternion_to_matrix(b)
|
||||
ab_matrix = torch.matmul(a_matrix, b_matrix)
|
||||
ab_from_matrix = matrix_to_quaternion(ab_matrix)
|
||||
self.assertEqual(ab.shape, ab_from_matrix.shape)
|
||||
self.assertClose(ab, ab_from_matrix)
|
||||
self._assert_quaternions_close(ab, ab_from_matrix)
|
||||
|
||||
def test_matrix_to_quaternion_corner_case(self):
|
||||
"""Check no bad gradients from sqrt(0)."""
|
||||
@@ -161,7 +162,34 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
self.assertClose(matrix, 0.95 * torch.eye(3))
|
||||
self.assertClose(matrix, matrix, msg="Result has non-finite values")
|
||||
delta = 1e-2
|
||||
self.assertLess(
|
||||
matrix.trace(),
|
||||
3.0 - delta,
|
||||
msg="Identity initialisation unchanged by a gradient step",
|
||||
)
|
||||
|
||||
def test_matrix_to_quaternion_by_pi(self):
|
||||
# We check that rotations by pi around each of the 26
|
||||
# nonzero vectors containing nothing but 0, 1 and -1
|
||||
# are mapped to the right quaternions.
|
||||
# This is representative across the directions.
|
||||
options = [0.0, -1.0, 1.0]
|
||||
axes = [
|
||||
torch.tensor(vec)
|
||||
for vec in itertools.islice( # exclude [0, 0, 0]
|
||||
itertools.product(options, options, options), 1, None
|
||||
)
|
||||
]
|
||||
|
||||
axes = torch.nn.functional.normalize(torch.stack(axes), dim=-1)
|
||||
# Rotation by pi around unit vector x is given by
|
||||
# the matrix 2 x x^T - Id.
|
||||
R = 2 * torch.matmul(axes[..., None], axes[..., None, :]) - torch.eye(3)
|
||||
quats_hat = matrix_to_quaternion(R)
|
||||
R_hat = quaternion_to_matrix(quats_hat)
|
||||
self.assertClose(R, R_hat, atol=1e-3)
|
||||
|
||||
def test_from_axis_angle(self):
|
||||
"""axis_angle -> mtx -> axis_angle"""
|
||||
@@ -228,3 +256,20 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(
|
||||
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
|
||||
)
|
||||
|
||||
def _assert_quaternions_close(
|
||||
self,
|
||||
input: Union[torch.Tensor, np.ndarray],
|
||||
other: Union[torch.Tensor, np.ndarray],
|
||||
*,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
msg: Optional[str] = None,
|
||||
):
|
||||
self.assertEqual(np.shape(input), np.shape(other))
|
||||
dot = (input * other).sum(-1)
|
||||
ones = torch.ones_like(dot)
|
||||
self.assertClose(
|
||||
dot.abs(), ones, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user