mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
6D representation of rotations.
Summary: Conversion to/from the 6D representation of rotation from the paper http://arxiv.org/abs/1812.07035 ; based on David’s implementation. Reviewed By: davnov134 Differential Revision: D22234397 fbshipit-source-id: 9e25ee93da7e3a2f2068cbe362cb5edc88649ce0
This commit is contained in:
parent
ce3da64917
commit
2f3cd98725
@ -4,6 +4,7 @@ import functools
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -252,7 +253,7 @@ def random_quaternions(
|
|||||||
i.e. versors with nonnegative real part.
|
i.e. versors with nonnegative real part.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n: Number to return.
|
n: Number of quaternions in a batch to return.
|
||||||
dtype: Type to return.
|
dtype: Type to return.
|
||||||
device: Desired device of returned tensor. Default:
|
device: Desired device of returned tensor. Default:
|
||||||
uses the current device for the default tensor type.
|
uses the current device for the default tensor type.
|
||||||
@ -275,7 +276,7 @@ def random_rotations(
|
|||||||
Generate random rotations as 3x3 rotation matrices.
|
Generate random rotations as 3x3 rotation matrices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n: Number to return.
|
n: Number of rotation matrices in a batch to return.
|
||||||
dtype: Type to return.
|
dtype: Type to return.
|
||||||
device: Device of returned tensor. Default: if None,
|
device: Device of returned tensor. Default: if None,
|
||||||
uses the current device for the default tensor type.
|
uses the current device for the default tensor type.
|
||||||
@ -400,3 +401,45 @@ def quaternion_apply(quaternion, point):
|
|||||||
quaternion_invert(quaternion),
|
quaternion_invert(quaternion),
|
||||||
)
|
)
|
||||||
return out[..., 1:]
|
return out[..., 1:]
|
||||||
|
|
||||||
|
|
||||||
|
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
||||||
|
using Gram--Schmidt orthogonalisation per Section B of [1].
|
||||||
|
Args:
|
||||||
|
d6: 6D rotation representation, of size (*, 6)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch of rotation matrices of size (*, 3, 3)
|
||||||
|
|
||||||
|
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
||||||
|
On the Continuity of Rotation Representations in Neural Networks.
|
||||||
|
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
||||||
|
Retrieved from http://arxiv.org/abs/1812.07035
|
||||||
|
"""
|
||||||
|
|
||||||
|
a1, a2 = d6[..., :3], d6[..., 3:]
|
||||||
|
b1 = F.normalize(a1, dim=-1)
|
||||||
|
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
||||||
|
b2 = F.normalize(b2, dim=-1)
|
||||||
|
b3 = torch.cross(b1, b2, dim=-1)
|
||||||
|
return torch.stack((b1, b2, b3), dim=-2)
|
||||||
|
|
||||||
|
|
||||||
|
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
||||||
|
by dropping the last row. Note that 6D representation is not unique.
|
||||||
|
Args:
|
||||||
|
matrix: batch of rotation matrices of size (*, 3, 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
6D rotation representation, of size (*, 6)
|
||||||
|
|
||||||
|
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
||||||
|
On the Continuity of Rotation Representations in Neural Networks.
|
||||||
|
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
||||||
|
Retrieved from http://arxiv.org/abs/1812.07035
|
||||||
|
"""
|
||||||
|
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
|
||||||
|
@ -10,7 +10,7 @@ HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
|||||||
def so3_relative_angle(R1, R2, cos_angle: bool = False):
|
def so3_relative_angle(R1, R2, cos_angle: bool = False):
|
||||||
"""
|
"""
|
||||||
Calculates the relative angle (in radians) between pairs of
|
Calculates the relative angle (in radians) between pairs of
|
||||||
rotation matrices `R1` and `R2` with `angle = acos(0.5 * Trace(R1 R2^T)-1)`
|
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
This corresponds to a geodesic distance on the 3D manifold of rotation
|
This corresponds to a geodesic distance on the 3D manifold of rotation
|
||||||
|
@ -45,9 +45,9 @@ class Transform3d:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
y1 = t3.transform_points(t2.transform_points(t2.transform_points(x)))
|
y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
|
||||||
y2 = t1.compose(t2).compose(t3).transform_points()
|
y2 = t1.compose(t2).compose(t3).transform_points(x)
|
||||||
y3 = t1.compose(t2, t3).transform_points()
|
y3 = t1.compose(t2, t3).transform_points(x)
|
||||||
|
|
||||||
|
|
||||||
Composing transforms should broadcast.
|
Composing transforms should broadcast.
|
||||||
|
@ -6,16 +6,19 @@ import math
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.transforms.rotation_conversions import (
|
from pytorch3d.transforms.rotation_conversions import (
|
||||||
euler_angles_to_matrix,
|
euler_angles_to_matrix,
|
||||||
matrix_to_euler_angles,
|
matrix_to_euler_angles,
|
||||||
matrix_to_quaternion,
|
matrix_to_quaternion,
|
||||||
|
matrix_to_rotation_6d,
|
||||||
quaternion_apply,
|
quaternion_apply,
|
||||||
quaternion_multiply,
|
quaternion_multiply,
|
||||||
quaternion_to_matrix,
|
quaternion_to_matrix,
|
||||||
random_quaternions,
|
random_quaternions,
|
||||||
random_rotation,
|
random_rotation,
|
||||||
random_rotations,
|
random_rotations,
|
||||||
|
rotation_6d_to_matrix,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +51,7 @@ class TestRandomRotation(unittest.TestCase):
|
|||||||
self.assertLess(chisquare_statistic, 12, (counts, chisquare_statistic, k))
|
self.assertLess(chisquare_statistic, 12, (counts, chisquare_statistic, k))
|
||||||
|
|
||||||
|
|
||||||
class TestRotationConversion(unittest.TestCase):
|
class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
@ -154,3 +157,31 @@ class TestRotationConversion(unittest.TestCase):
|
|||||||
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
|
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
|
||||||
self.assertTrue(torch.isfinite(p).all())
|
self.assertTrue(torch.isfinite(p).all())
|
||||||
self.assertTrue(torch.isfinite(q).all())
|
self.assertTrue(torch.isfinite(q).all())
|
||||||
|
|
||||||
|
def test_6d(self):
|
||||||
|
"""Converting to 6d and back"""
|
||||||
|
r = random_rotations(13, dtype=torch.float64)
|
||||||
|
|
||||||
|
# 6D representation is not unique,
|
||||||
|
# but we implement it by taking the first two rows of the matrix
|
||||||
|
r6d = matrix_to_rotation_6d(r)
|
||||||
|
self.assertClose(r6d, r[:, :2, :].reshape(-1, 6))
|
||||||
|
|
||||||
|
# going to 6D and back should not change the matrix
|
||||||
|
r_hat = rotation_6d_to_matrix(r6d)
|
||||||
|
self.assertClose(r_hat, r)
|
||||||
|
|
||||||
|
# moving the second row R2 in the span of (R1, R2) should not matter
|
||||||
|
r6d[:, 3:] += 2 * r6d[:, :3]
|
||||||
|
r6d[:, :3] *= 3.0
|
||||||
|
r_hat = rotation_6d_to_matrix(r6d)
|
||||||
|
self.assertClose(r_hat, r)
|
||||||
|
|
||||||
|
# check that we map anything to a valid rotation
|
||||||
|
r6d = torch.rand(13, 6)
|
||||||
|
r6d[:4, :] *= 3.0
|
||||||
|
r6d[4:8, :] -= 0.5
|
||||||
|
r = rotation_6d_to_matrix(r6d)
|
||||||
|
self.assertClose(
|
||||||
|
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user