6D representation of rotations.

Summary: Conversion to/from the 6D representation of rotation from the paper http://arxiv.org/abs/1812.07035 ; based on David’s implementation.

Reviewed By: davnov134

Differential Revision: D22234397

fbshipit-source-id: 9e25ee93da7e3a2f2068cbe362cb5edc88649ce0
This commit is contained in:
Roman Shapovalov
2020-07-08 03:59:51 -07:00
committed by Facebook GitHub Bot
parent ce3da64917
commit 2f3cd98725
4 changed files with 81 additions and 7 deletions

View File

@@ -4,6 +4,7 @@ import functools
from typing import Optional
import torch
import torch.nn.functional as F
"""
@@ -252,7 +253,7 @@ def random_quaternions(
i.e. versors with nonnegative real part.
Args:
n: Number to return.
n: Number of quaternions in a batch to return.
dtype: Type to return.
device: Desired device of returned tensor. Default:
uses the current device for the default tensor type.
@@ -275,7 +276,7 @@ def random_rotations(
Generate random rotations as 3x3 rotation matrices.
Args:
n: Number to return.
n: Number of rotation matrices in a batch to return.
dtype: Type to return.
device: Device of returned tensor. Default: if None,
uses the current device for the default tensor type.
@@ -400,3 +401,45 @@ def quaternion_apply(quaternion, point):
quaternion_invert(quaternion),
)
return out[..., 1:]
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
"""
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
using Gram--Schmidt orthogonalisation per Section B of [1].
Args:
d6: 6D rotation representation, of size (*, 6)
Returns:
batch of rotation matrices of size (*, 3, 3)
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
On the Continuity of Rotation Representations in Neural Networks.
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035
"""
a1, a2 = d6[..., :3], d6[..., 3:]
b1 = F.normalize(a1, dim=-1)
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
b2 = F.normalize(b2, dim=-1)
b3 = torch.cross(b1, b2, dim=-1)
return torch.stack((b1, b2, b3), dim=-2)
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
"""
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
by dropping the last row. Note that 6D representation is not unique.
Args:
matrix: batch of rotation matrices of size (*, 3, 3)
Returns:
6D rotation representation, of size (*, 6)
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
On the Continuity of Rotation Representations in Neural Networks.
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035
"""
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)

View File

@@ -10,7 +10,7 @@ HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
def so3_relative_angle(R1, R2, cos_angle: bool = False):
"""
Calculates the relative angle (in radians) between pairs of
rotation matrices `R1` and `R2` with `angle = acos(0.5 * Trace(R1 R2^T)-1)`
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
.. note::
This corresponds to a geodesic distance on the 3D manifold of rotation

View File

@@ -45,9 +45,9 @@ class Transform3d:
.. code-block:: python
y1 = t3.transform_points(t2.transform_points(t2.transform_points(x)))
y2 = t1.compose(t2).compose(t3).transform_points()
y3 = t1.compose(t2, t3).transform_points()
y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
y2 = t1.compose(t2).compose(t3).transform_points(x)
y3 = t1.compose(t2, t3).transform_points(x)
Composing transforms should broadcast.