mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Summary: Conversion to/from the 6D representation of rotation from the paper http://arxiv.org/abs/1812.07035 ; based on David’s implementation. Reviewed By: davnov134 Differential Revision: D22234397 fbshipit-source-id: 9e25ee93da7e3a2f2068cbe362cb5edc88649ce0
240 lines
6.9 KiB
Python
240 lines
6.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
|
|
import torch
|
|
|
|
|
|
# Largest absolute entry of `h + h^T` that `hat_inv` tolerates before it
# rejects an input matrix `h` as not skew-symmetric.
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
|
|
|
|
|
def so3_relative_angle(R1, R2, cos_angle: bool = False, eps: float = 1e-4):
    """
    Calculates the relative angle (in radians) between pairs of
    rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`

    .. note::
        This corresponds to a geodesic distance on the 3D manifold of rotation
        matrices.

    Args:
        R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        cos_angle: If==True return cosine of the relative angle rather than
            the angle itself. This can avoid the unstable
            calculation of `acos`.
        eps: Tolerance for the valid-trace check of the relative rotation
            `R1 R2^T`, forwarded to `so3_rotation_angle`.

    Returns:
        Corresponding rotation angles of shape `(minibatch,)`.
        If `cos_angle==True`, returns the cosine of the angles.

    Raises:
        ValueError if `R1` or `R2` is of incorrect shape, or if their batch
            sizes differ.
        ValueError if `R1 R2^T` has an unexpected trace.
    """
    # Validate before `torch.bmm`, which would otherwise raise a RuntimeError
    # rather than the documented ValueError.
    for name, rot in (("R1", R1), ("R2", R2)):
        if rot.dim() != 3 or tuple(rot.shape[1:]) != (3, 3):
            raise ValueError(f"{name} has to be a batch of 3x3 Tensors.")
    if R1.shape[0] != R2.shape[0]:
        raise ValueError("R1 and R2 have to have the same batch size.")

    R12 = torch.bmm(R1, R2.permute(0, 2, 1))
    return so3_rotation_angle(R12, eps=eps, cos_angle=cos_angle)
|
|
|
|
|
|
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
    """
    Calculates angles (in radians) of a batch of rotation matrices `R` with
    `angle = acos(0.5 * (Trace(R)-1))`. The trace of the
    input matrices is checked to be in the valid range `[-1-eps,3+eps]`.
    The `eps` argument is a small constant that allows for small errors
    caused by limited machine precision.

    Args:
        R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: Tolerance for the valid trace check.
        cos_angle: If==True return cosine of the rotation angles rather than
            the angle itself. This can avoid the unstable
            calculation of `acos`.

    Returns:
        Corresponding rotation angles of shape `(minibatch,)`.
        If `cos_angle==True`, returns the cosine of the angles.

    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    # Explicit rank check: tuple-unpacking `R.shape` would fail with an
    # opaque "values to unpack" message for non-3D inputs.
    if R.dim() != 3 or R.shape[1] != 3 or R.shape[2] != 3:
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]

    # A rotation matrix has trace 1 + 2*cos(angle), i.e. in [-1, 3];
    # `eps` leaves slack for floating-point error.
    if ((rot_trace < -1.0 - eps) | (rot_trace > 3.0 + eps)).any():
        raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")

    # Remove the tolerated slack so the cosine below lies within [-1, 1].
    rot_trace = torch.clamp(rot_trace, -1.0, 3.0)

    # Cosine of the rotation angle (not the angle itself).
    cos_phi = 0.5 * (rot_trace - 1.0)

    if cos_angle:
        return cos_phi
    return cos_phi.acos()
|
|
|
|
|
|
def so3_exponential_map(log_rot, eps: float = 0.0001):
    """
    Map a batch of axis-angle (logarithmic) vectors to 3x3 rotation matrices
    with Rodrigues' formula [1]:

        R = I + sin(theta)/theta * K + (1 - cos(theta))/theta^2 * K @ K

    where `K = hat(log_rot)` and `theta` is the l2-norm of `log_rot`. The
    direction of each vector is the rotation axis and its norm is the rotation
    angle.

    The formula is singular at `log_rot = 0`. To avoid it, the squared norm is
    clamped from below by `eps` before the square root is taken, so `theta`
    is never smaller than `sqrt(eps)`.

    Args:
        log_rot: Batch of vectors of shape `(minibatch , 3)`.
        eps: A float constant handling the conversion singularity.

    Returns:
        Batch of rotation matrices of shape `(minibatch , 3 , 3)`.

    Raises:
        ValueError if `log_rot` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    """
    _, width = log_rot.shape
    if width != 3:
        raise ValueError("Input tensor shape has to be Nx3.")

    # Rotation angles, kept away from zero by the clamp on the squared norm.
    theta = (log_rot * log_rot).sum(1).clamp(min=eps).sqrt()
    inv_theta = 1.0 / theta

    # Per-sample Rodrigues coefficients, broadcast over the 3x3 matrices.
    coef_k = (inv_theta * theta.sin())[:, None, None]
    coef_kk = (inv_theta * inv_theta * (1.0 - theta.cos()))[:, None, None]

    K = hat(log_rot)
    identity = torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]

    return coef_k * K + coef_kk * torch.bmm(K, K) + identity
|
|
|
|
|
|
def so3_log_map(R, eps: float = 0.0001):
    """
    Map a batch of 3x3 rotation matrices `R` to their 3-dimensional matrix
    logarithms (axis-angle vectors).

    The logarithm is recovered from the antisymmetric part of `R`:

        hat(log R) = phi / (2 sin(phi)) * (R - R^T)

    where `phi` is the rotation angle. Division by `sin(phi)` is singular at
    `R = I`. There, `|sin(phi)|` is clamped from below by `eps` with its sign
    preserved, and an exactly-zero `sin(phi)` is replaced by `eps`.

    NOTE: as `phi` approaches pi, both `sin(phi)` and `R - R^T` tend to zero,
    so this formula loses accuracy near half-turn rotations.

    Args:
        R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: A float constant handling the conversion singularity.

    Returns:
        Batch of logarithms of input rotation matrices
        of shape `(minibatch, 3)`.

    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    _, rows, cols = R.shape
    if (rows, cols) != (3, 3):
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    angle = so3_rotation_angle(R)
    sin_angle = angle.sin()

    # Sign-preserving clamp of sin(angle) away from zero; sign(0) == 0, so
    # exact zeros are then bumped up to +eps.
    safe_sin = torch.clamp(sin_angle.abs(), eps) * sin_angle.sign()
    safe_sin = safe_sin + (sin_angle == 0).type_as(angle) * eps

    scale = angle / (2.0 * safe_sin)
    skew = scale[:, None, None] * (R - R.transpose(1, 2))

    return hat_inv(skew)
|
|
|
|
|
|
def hat_inv(h, tol=None):
    """
    Compute the inverse Hat operator [1] of a batch of 3x3 matrices.

    Args:
        h: Batch of skew-symmetric matrices of shape `(minibatch, 3, 3)`.
        tol: Largest absolute entry of `h + h^T` tolerated before `h` is
            rejected as not skew-symmetric. When None (the default), the
            module-level `HAT_INV_SKEW_SYMMETRIC_TOL` is used.

    Returns:
        Batch of 3d vectors of shape `(minibatch, 3)`.

    Raises:
        ValueError if `h` is of incorrect shape.
        ValueError if `h` not skew-symmetric.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    _, dim1, dim2 = h.shape
    if dim1 != 3 or dim2 != 3:
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    if tol is None:
        tol = HAT_INV_SKEW_SYMMETRIC_TOL

    # `any` is well-defined on an empty batch, whereas `.max()` of an empty
    # tensor raises a RuntimeError.
    asym = (h + h.permute(0, 2, 1)).abs()
    if bool((asym > tol).any()):
        raise ValueError("One of input matrices not skew-symmetric.")

    # Read (x, y, z) from the entries hat(v) places them at:
    # h[2, 1] = x, h[0, 2] = y, h[1, 0] = z.
    x = h[:, 2, 1]
    y = h[:, 0, 2]
    z = h[:, 1, 0]

    return torch.stack((x, y, z), dim=1)
|
|
|
|
|
|
def hat(v):
    """
    Build the skew-symmetric cross-product matrix (Hat operator [1]) for each
    vector in a batch, so that `hat(v)[i] @ u == cross(v[i], u)`.

    Args:
        v: Batch of vectors of shape `(minibatch , 3)`.

    Returns:
        Batch of skew-symmetric matrices of shape
        `(minibatch, 3 , 3)` where each matrix is of the form:
            `[ 0 -v_z v_y ]
             [ v_z 0 -v_x ]
             [ -v_y v_x 0 ]`

    Raises:
        ValueError if `v` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    _, dim = v.shape
    if dim != 3:
        raise ValueError("Input vectors have to be 3-dimensional.")

    vx, vy, vz = v.unbind(1)
    zero = torch.zeros_like(vx)

    # Assemble row by row, then stack rows into (minibatch, 3, 3).
    top = torch.stack((zero, -vz, vy), dim=1)
    middle = torch.stack((vz, zero, -vx), dim=1)
    bottom = torch.stack((-vy, vx, zero), dim=1)

    return torch.stack((top, middle, bottom), dim=1)
|