Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
facebook-github-bot
2020-01-23 11:53:41 -08:00
commit dbf06b504b
211 changed files with 47362 additions and 0 deletions

View File

@@ -0,0 +1,374 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import functools
import torch
def quaternion_to_matrix(quaternions):
"""
Convert rotations given as quaternions to rotation matrices.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
r, i, j, k = torch.unbind(quaternions, -1)
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def _copysign(a, b):
"""
Return a tensor where each element has the absolute value taken from the,
corresponding element of a, with sign taken from the corresponding
element of b. This is like the standard copysign floating-point operation,
but is not careful about negative 0 and NaN.
Args:
a: source tensor.
b: tensor whose signs will be used, of the same shape as a.
Returns:
Tensor of the same shape as a with the signs of b.
"""
signs_differ = (a < 0) != (b < 0)
return torch.where(signs_differ, -a, a)
def matrix_to_quaternion(matrix):
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
zero = matrix.new_zeros((1,))
m00 = matrix[..., 0, 0]
m11 = matrix[..., 1, 1]
m22 = matrix[..., 2, 2]
o0 = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 + m11 + m22))
x = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 - m11 - m22))
y = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 + m11 - m22))
z = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 - m11 + m22))
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
return torch.stack((o0, o1, o2, o3), -1)
def _primary_matrix(axis: str, angle):
"""
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
Args:
axis: Axis label "X" or "Y or "Z".
angle: any shape tensor of Euler angles in radians
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
o = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
o = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
o = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
return torch.stack(o, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str):
"""
Convert rotations given as Euler angles in radians to rotation matrices.
Args:
euler_angles: Euler angles in radians as tensor of shape (..., 3).
convention: Convention string of three uppercase letters from
{"X", "Y", and "Z"}.
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
raise ValueError("Invalid input euler angles.")
if len(convention) != 3:
raise ValueError("Convention must have 3 letters.")
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(_primary_matrix, convention, torch.unbind(euler_angles, -1))
return functools.reduce(torch.matmul, matrices)
def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
):
"""
Extract the first or third Euler angle from the two members of
the matrix which are positive constant times its sine and cosine.
Args:
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
convention.
data: Rotation matrices as tensor of shape (..., 3, 3).
horizontal: Whether we are looking for the angle for the third axis,
which means the relevant entries are in the same row of the
rotation matrix. If not, they are in the same column.
tait_bryan: Whether the first and third axes in the convention differ.
Returns:
Euler Angles in radians for each matrix in data as a tensor
of shape (...).
"""
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])
def _index_from_letter(letter: str):
if letter == "X":
return 0
if letter == "Y":
return 1
if letter == "Z":
return 2
def matrix_to_euler_angles(matrix, convention: str):
"""
Convert rotations given as rotation matrices to Euler angles in radians.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
convention: Convention string of three uppercase letters.
Returns:
Euler angles in radians as tensor of shape (..., 3).
"""
if len(convention) != 3:
raise ValueError("Convention must have 3 letters.")
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
i0 = _index_from_letter(convention[0])
i2 = _index_from_letter(convention[2])
tait_bryan = i0 != i2
if tait_bryan:
central_angle = torch.asin(
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
)
else:
central_angle = torch.acos(matrix[..., i0, i0])
o = (
_angle_from_tan(
convention[0], convention[1], matrix[..., i2], False, tait_bryan
),
central_angle,
_angle_from_tan(
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
),
)
return torch.stack(o, -1)
def random_quaternions(
n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
"""
Generate random quaternions representing rotations,
i.e. versors with nonnegative real part.
Args:
n: Number to return.
dtype: Type to return.
device: Desired device of returned tensor. Default:
uses the current device for the default tensor type.
requires_grad: Whether the resulting tensor should have the gradient
flag set.
Returns:
Quaternions as tensor of shape (N, 4).
"""
o = torch.randn(
(n, 4), dtype=dtype, device=device, requires_grad=requires_grad
)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
return o
def random_rotations(
n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
"""
Generate random rotations as 3x3 rotation matrices.
Args:
n: Number to return.
dtype: Type to return.
device: Device of returned tensor. Default: if None,
uses the current device for the default tensor type.
requires_grad: Whether the resulting tensor should have the gradient
flag set.
Returns:
Rotation matrices as tensor of shape (n, 3, 3).
"""
quaternions = random_quaternions(
n, dtype=dtype, device=device, requires_grad=requires_grad
)
return quaternion_to_matrix(quaternions)
def random_rotation(
dtype: torch.dtype = None, device=None, requires_grad=False
):
"""
Generate a single random 3x3 rotation matrix.
Args:
dtype: Type to return
device: Device of returned tensor. Default: if None,
uses the current device for the default tensor type
requires_grad: Whether the resulting tensor should have the gradient
flag set
Returns:
Rotation matrix as tensor of shape (3, 3).
"""
return random_rotations(1, dtype, device, requires_grad)[0]
def standardize_quaternion(quaternions):
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
def quaternion_raw_multiply(a, b):
"""
Multiply two quaternions.
Usual torch rules for broadcasting apply.
Args:
a: Quaternions as tensor of shape (..., 4), real part first.
b: Quaternions as tensor of shape (..., 4), real part first.
Returns:
The product of a and b, a tensor of quaternions shape (..., 4).
"""
aw, ax, ay, az = torch.unbind(a, -1)
bw, bx, by, bz = torch.unbind(b, -1)
ow = aw * bw - ax * bx - ay * by - az * bz
ox = aw * bx + ax * bw + ay * bz - az * by
oy = aw * by - ax * bz + ay * bw + az * bx
oz = aw * bz + ax * by - ay * bx + az * bw
return torch.stack((ow, ox, oy, oz), -1)
def quaternion_multiply(a, b):
"""
Multiply two quaternions representing rotations, returning the quaternion
representing their composition, i.e. the versor with nonnegative real part.
Usual torch rules for broadcasting apply.
Args:
a: Quaternions as tensor of shape (..., 4), real part first.
b: Quaternions as tensor of shape (..., 4), real part first.
Returns:
The product of a and b, a tensor of quaternions of shape (..., 4).
"""
ab = quaternion_raw_multiply(a, b)
return standardize_quaternion(ab)
def quaternion_invert(quaternion):
"""
Given a quaternion representing rotation, get the quaternion representing
its inverse.
Args:
quaternion: Quaternions as tensor of shape (..., 4), with real part
first, which must be versors (unit quaternions).
Returns:
The inverse, a tensor of quaternions of shape (..., 4).
"""
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
def quaternion_apply(quaternion, point):
"""
Apply the rotation given by a quaternion to a 3D point.
Usual torch rules for broadcasting apply.
Args:
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
point: Tensor of 3D points of shape (..., 3).
Returns:
Tensor of rotated points of shape (..., 3).
"""
if point.size(-1) != 3:
raise ValueError(f"Points are not in 3D, f{point.shape}.")
real_parts = point.new_zeros(point.shape[:-1] + (1,))
point_as_quaternion = torch.cat((real_parts, point), -1)
out = quaternion_raw_multiply(
quaternion_raw_multiply(quaternion, point_as_quaternion),
quaternion_invert(quaternion),
)
return out[..., 1:]