From 2f3cd987253c5da3741af9ceaec95c1bdbf12583 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Wed, 8 Jul 2020 03:59:51 -0700 Subject: [PATCH] 6D representation of rotations. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Conversion to/from the 6D representation of rotation from the paper http://arxiv.org/abs/1812.07035 ; based on David’s implementation. Reviewed By: davnov134 Differential Revision: D22234397 fbshipit-source-id: 9e25ee93da7e3a2f2068cbe362cb5edc88649ce0 --- pytorch3d/transforms/rotation_conversions.py | 47 +++++++++++++++++++- pytorch3d/transforms/so3.py | 2 +- pytorch3d/transforms/transform3d.py | 6 +-- tests/test_rotation_conversions.py | 33 +++++++++++++- 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index a30e3759..e322280d 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -4,6 +4,7 @@ import functools from typing import Optional import torch +import torch.nn.functional as F """ @@ -252,7 +253,7 @@ def random_quaternions( i.e. versors with nonnegative real part. Args: - n: Number to return. + n: Number of quaternions in a batch to return. dtype: Type to return. device: Desired device of returned tensor. Default: uses the current device for the default tensor type. @@ -275,7 +276,7 @@ def random_rotations( Generate random rotations as 3x3 rotation matrices. Args: - n: Number to return. + n: Number of rotation matrices in a batch to return. dtype: Type to return. device: Device of returned tensor. Default: if None, uses the current device for the default tensor type. @@ -400,3 +401,45 @@ def quaternion_apply(quaternion, point): quaternion_invert(quaternion), ) return out[..., 1:] + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/pytorch3d/transforms/so3.py b/pytorch3d/transforms/so3.py index 9345a678..8371d8a6 100644 --- a/pytorch3d/transforms/so3.py +++ b/pytorch3d/transforms/so3.py @@ -10,7 +10,7 @@ HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5 def so3_relative_angle(R1, R2, cos_angle: bool = False): """ Calculates the relative angle (in radians) between pairs of - rotation matrices `R1` and `R2` with `angle = acos(0.5 * Trace(R1 R2^T)-1)` + rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))` .. note:: This corresponds to a geodesic distance on the 3D manifold of rotation diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 21ec6e9e..ac061cbc 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -45,9 +45,9 @@ class Transform3d: .. code-block:: python - y1 = t3.transform_points(t2.transform_points(t2.transform_points(x))) - y2 = t1.compose(t2).compose(t3).transform_points() - y3 = t1.compose(t2, t3).transform_points() + y1 = t3.transform_points(t2.transform_points(t1.transform_points(x))) + y2 = t1.compose(t2).compose(t3).transform_points(x) + y3 = t1.compose(t2, t3).transform_points(x) Composing transforms should broadcast. diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py index a65b00cf..f8bc60fb 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -6,16 +6,19 @@ import math import unittest import torch +from common_testing import TestCaseMixin from pytorch3d.transforms.rotation_conversions import ( euler_angles_to_matrix, matrix_to_euler_angles, matrix_to_quaternion, + matrix_to_rotation_6d, quaternion_apply, quaternion_multiply, quaternion_to_matrix, random_quaternions, random_rotation, random_rotations, + rotation_6d_to_matrix, ) @@ -48,7 +51,7 @@ class TestRandomRotation(unittest.TestCase): self.assertLess(chisquare_statistic, 12, (counts, chisquare_statistic, k)) -class TestRotationConversion(unittest.TestCase): +class TestRotationConversion(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() torch.manual_seed(1) @@ -154,3 +157,31 @@ class TestRotationConversion(unittest.TestCase): [p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions]) self.assertTrue(torch.isfinite(p).all()) self.assertTrue(torch.isfinite(q).all()) + + def test_6d(self): + """Converting to 6d and back""" + r = random_rotations(13, dtype=torch.float64) + + # 6D representation is not unique, + # but we implement it by taking the first two rows of the matrix + r6d = matrix_to_rotation_6d(r) + self.assertClose(r6d, r[:, :2, :].reshape(-1, 6)) + + # going to 6D and back should not change the matrix + r_hat = rotation_6d_to_matrix(r6d) + self.assertClose(r_hat, r) + + # moving the second row R2 in the span of (R1, R2) should not matter + r6d[:, 3:] += 2 * r6d[:, :3] + r6d[:, :3] *= 3.0 + r_hat = rotation_6d_to_matrix(r6d) + self.assertClose(r_hat, r) + + # check that we map anything to a valid rotation + r6d = torch.rand(13, 6) + r6d[:4, :] *= 3.0 + r6d[4:8, :] -= 0.5 + r = rotation_6d_to_matrix(r6d) + self.assertClose( + torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6 + )