mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 23:00:34 +08:00
Initial commit
fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
25
pytorch3d/transforms/__init__.py
Normal file
25
pytorch3d/transforms/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .rotation_conversions import (
|
||||
euler_angles_to_matrix,
|
||||
matrix_to_euler_angles,
|
||||
matrix_to_quaternion,
|
||||
quaternion_apply,
|
||||
quaternion_invert,
|
||||
quaternion_multiply,
|
||||
quaternion_raw_multiply,
|
||||
quaternion_to_matrix,
|
||||
random_quaternions,
|
||||
random_rotation,
|
||||
random_rotations,
|
||||
standardize_quaternion,
|
||||
)
|
||||
from .so3 import (
|
||||
so3_exponential_map,
|
||||
so3_log_map,
|
||||
so3_relative_angle,
|
||||
so3_rotation_angle,
|
||||
)
|
||||
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
374
pytorch3d/transforms/rotation_conversions.py
Normal file
374
pytorch3d/transforms/rotation_conversions.py
Normal file
@@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import functools
|
||||
import torch
|
||||
|
||||
|
||||
def quaternion_to_matrix(quaternions):
    """
    Build rotation matrices from quaternions.

    Every product of components is scaled by 2 / |q|^2, so the result
    does not depend on the overall scale of the quaternion.

    Args:
        quaternions: tensor of shape (..., 4) holding quaternions with
            the real part first.

    Returns:
        Tensor of shape (..., 3, 3) holding the rotation matrices.
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    row0 = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
    )
    row1 = (
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
    )
    row2 = (
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )

    # The nine entries are stacked row by row and folded into 3x3.
    flat = torch.stack(row0 + row1 + row2, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
|
||||
|
||||
|
||||
def _copysign(a, b):
    """
    Give each element of a the sign of the matching element of b.

    An element of a is negated exactly where `a < 0` and `b < 0`
    disagree. Unlike the standard floating-point copysign, negative zero
    and NaN get no special treatment.

    Args:
        a: source tensor.
        b: tensor whose signs will be used, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    flip = (a < 0) ^ (b < 0)
    return torch.where(flip, -a, a)
|
||||
|
||||
|
||||
def matrix_to_quaternion(matrix):
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of matrix are not 3x3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    zero = matrix.new_zeros((1,))
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    # torch.max(zero, ...) keeps every sqrt argument non-negative.
    o0 = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 + m11 + m22))
    x = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 - m11 - m22))
    y = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 + m11 - m22))
    z = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 - m11 + m22))
    # The magnitudes come from the diagonal; the signs of the imaginary
    # parts come from the off-diagonal differences.
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)
|
||||
|
||||
|
||||
def _primary_matrix(axis: str, angle):
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X" or "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if axis is not one of "X", "Y" or "Z".
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)
    # The nine entries are listed row by row.
    if axis == "X":
        entries = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        entries = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        entries = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        # Without this branch an unknown axis left the entries unbound.
        raise ValueError(f"Invalid axis {axis!r}: expected 'X', 'Y' or 'Z'.")
    return torch.stack(entries, -1).reshape(angle.shape + (3, 3))
|
||||
|
||||
|
||||
def euler_angles_to_matrix(euler_angles, convention: str):
    """
    Turn Euler angles (radians) into rotation matrices.

    One elementary rotation is built per axis letter and the three are
    multiplied together in the order given by the convention.

    Args:
        euler_angles: tensor of shape (..., 3) with angles in radians.
        convention: three uppercase letters from {"X", "Y", "Z"}; the
            middle letter must differ from both of its neighbours.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if the angles or the convention are malformed.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    first, middle, last = convention
    if middle in (first, last):
        raise ValueError(f"Invalid convention {convention}.")
    for axis_label in convention:
        if axis_label not in ("X", "Y", "Z"):
            raise ValueError(
                f"Invalid letter {axis_label} in convention string."
            )

    per_axis = [
        _primary_matrix(axis_label, angle)
        for axis_label, angle in zip(convention, torch.unbind(euler_angles, -1))
    ]
    product = per_axis[0]
    for factor in per_axis[1:]:
        product = torch.matmul(product, factor)
    return product
|
||||
|
||||
|
||||
def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
):
    """
    Recover the first or third Euler angle with atan2 from two entries
    of `data`.

    Args:
        axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y" or "Z" for the middle axis in
            the convention.
        data: tensor that is indexed along its last dimension only.
            NOTE(review): matrix_to_euler_angles passes a single row or
            column of the rotation matrices, i.e. shape (..., 3).
        horizontal: Whether we are looking for the angle for the third
            axis, which means the relevant entries are in the same row
            of the rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention
            differ.

    Returns:
        Euler angles in radians, one per leading index of data.
    """
    index_pairs = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}
    first, second = index_pairs[axis]
    if horizontal:
        first, second = second, first

    is_even = (axis + other_axis) in ("XY", "YZ", "ZX")
    if horizontal != is_even:
        # Exactly one of the two entries has to be negated; which one
        # depends on whether the convention repeats its outer axis.
        if tait_bryan:
            return torch.atan2(-data[..., second], data[..., first])
        return torch.atan2(data[..., second], -data[..., first])
    return torch.atan2(data[..., first], data[..., second])
|
||||
|
||||
|
||||
def _index_from_letter(letter: str):
    """
    Map an axis label to its index: "X" -> 0, "Y" -> 1, "Z" -> 2.

    Args:
        letter: Axis label "X" or "Y" or "Z".

    Returns:
        The integer index of the axis.

    Raises:
        ValueError: if letter is not one of the three labels (previously
            the function fell through and returned None).
    """
    if letter == "X":
        return 0
    if letter == "Y":
        return 1
    if letter == "Z":
        return 2
    raise ValueError("letter must be either X, Y or Z.")
|
||||
|
||||
|
||||
def matrix_to_euler_angles(matrix, convention: str):
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", "Z"}; the middle letter must differ from both of
            its neighbours.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: if the convention is malformed or the last two
            dimensions of matrix are not 3x3.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    # Distinct outer axes: the middle angle comes from asin of one
    # off-diagonal entry; repeated outer axes: from acos of a diagonal one.
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        # First angle: read from column i2 of each matrix.
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        # Third angle: read from row i0 of each matrix.
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)
|
||||
|
||||
|
||||
def random_quaternions(
    n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
    """
    Draw random rotation quaternions, i.e. unit quaternions whose real
    part is nonnegative.

    Args:
        n: Number to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.
        requires_grad: passed to torch.randn for the raw samples the
            result is computed from.

    Returns:
        Quaternions as tensor of shape (n, 4).
    """
    samples = torch.randn(
        (n, 4), dtype=dtype, device=device, requires_grad=requires_grad
    )
    norms = torch.sqrt((samples * samples).sum(1))
    real_part = samples[:, 0]
    # Dividing by a norm that carries the sign of the real part both
    # normalizes the sample and makes its real part nonnegative.
    signed_norms = torch.where(real_part < 0, -norms, norms)
    return samples / signed_norms[:, None]
|
||||
|
||||
|
||||
def random_rotations(
    n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
    """
    Draw n random rotations and return them as 3x3 matrices.

    The rotations are sampled as quaternions with random_quaternions and
    converted with quaternion_to_matrix.

    Args:
        n: Number to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.
        requires_grad: forwarded to random_quaternions.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    sampling_options = {
        "dtype": dtype,
        "device": device,
        "requires_grad": requires_grad,
    }
    return quaternion_to_matrix(random_quaternions(n, **sampling_options))
|
||||
|
||||
|
||||
def random_rotation(
    dtype: torch.dtype = None, device=None, requires_grad=False
):
    """
    Draw one random rotation as a 3x3 matrix.

    Args:
        dtype: Type to return
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type
        requires_grad: forwarded to random_rotations.

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    # Sample a batch of one and drop the batch dimension.
    single_batch = random_rotations(1, dtype, device, requires_grad)
    return single_batch[0]
|
||||
|
||||
|
||||
def standardize_quaternion(quaternions):
    """
    Bring quaternions into the standard form with a non-negative real
    part by negating every quaternion whose real part is negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # Slicing with 0:1 keeps the last dimension so the mask broadcasts
    # over all four components.
    negative_real = quaternions[..., 0:1] < 0
    return torch.where(negative_real, -quaternions, quaternions)
|
||||
|
||||
|
||||
def quaternion_raw_multiply(a, b):
    """
    Compute the quaternion product of a and b without any sign
    standardization. Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    a_w, a_x, a_y, a_z = torch.unbind(a, -1)
    b_w, b_x, b_y, b_z = torch.unbind(b, -1)
    product = (
        a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
        a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
        a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
        a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
    )
    return torch.stack(product, -1)
|
||||
|
||||
|
||||
def quaternion_multiply(a, b):
    """
    Multiply two quaternions representing rotations and return the
    product in standard form, i.e. with a nonnegative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    return standardize_quaternion(quaternion_raw_multiply(a, b))
|
||||
|
||||
|
||||
def quaternion_invert(quaternion):
    """
    Return the conjugate of each quaternion, obtained by negating the
    three imaginary components. For a unit quaternion this is the
    quaternion of the inverse rotation.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real
            part first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """
    # new_tensor creates the sign pattern with the input's dtype and device.
    conjugate_signs = quaternion.new_tensor([1, -1, -1, -1])
    return quaternion * conjugate_signs
|
||||
|
||||
|
||||
def quaternion_apply(quaternion, point):
    """
    Apply the rotation given by a quaternion to a 3D point.
    Usual torch rules for broadcasting apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: if the last dimension of point is not 3.
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    # Embed each point as a quaternion with zero real part, then compute
    # q * p * quaternion_invert(q) and keep the imaginary part.
    real_parts = point.new_zeros(point.shape[:-1] + (1,))
    point_as_quaternion = torch.cat((real_parts, point), -1)
    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    return out[..., 1:]
|
||||
236
pytorch3d/transforms/so3.py
Normal file
236
pytorch3d/transforms/so3.py
Normal file
@@ -0,0 +1,236 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
||||
|
||||
|
||||
def so3_relative_angle(R1, R2, cos_angle: bool = False):
    """
    Compute the relative angle (in radians) between pairs of rotation
    matrices `R1` and `R2`, i.e. the rotation angle of `R1 R2^T`:
    `angle = acos(0.5 * (Trace(R1 R2^T) - 1))`.

    .. note::
        This corresponds to a geodesic distance on the 3D manifold of
        rotation matrices.

    Args:
        R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        cos_angle: If==True return cosine of the relative angle rather
            than the angle itself. This can avoid the unstable
            calculation of `acos`.

    Returns:
        Corresponding rotation angles of shape `(minibatch,)`.
        If `cos_angle==True`, returns the cosine of the angles.

    Raises:
        ValueError if the product `R1 R2^T` is rejected by
        `so3_rotation_angle` (wrong shape or unexpected trace).
    """
    R2_transposed = R2.permute(0, 2, 1)
    relative_rotation = torch.bmm(R1, R2_transposed)
    return so3_rotation_angle(relative_rotation, cos_angle=cos_angle)
|
||||
|
||||
|
||||
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
    """
    Compute the rotation angle (in radians) of every matrix in a batch
    via `angle = acos(0.5 * (Trace(R)-1))`.

    Traces outside `[-1-eps, 3+eps]` are rejected; traces inside that
    band are clamped to `[-1, 3]` before the cosine is computed.

    Args:
        R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: Tolerance for the valid trace check.
        cos_angle: If==True return cosine of the rotation angles rather
            than the angle itself. This can avoid the unstable
            calculation of `acos`.

    Returns:
        Corresponding rotation angles of shape `(minibatch,)`.
        If `cos_angle==True`, returns the cosine of the angles.

    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    _, rows, cols = R.shape
    if not (rows == 3 and cols == 3):
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]

    below_range = trace < -1.0 - eps
    above_range = trace > 3.0 + eps
    if (below_range | above_range).any():
        raise ValueError(
            "A matrix has trace outside valid range [-1-eps,3+eps]."
        )

    # Clamp so that the cosine below always lies in [-1, 1].
    cos_phi = 0.5 * (torch.clamp(trace, -1.0, 3.0) - 1.0)
    return cos_phi if cos_angle else cos_phi.acos()
|
||||
|
||||
|
||||
def so3_exponential_map(log_rot, eps: float = 0.0001):
    """
    Map a batch of rotation logarithms `log_rot` to 3x3 rotation
    matrices using Rodrigues formula [1].

    Each logarithm is a 3-dimensional vector whose l2-norm and direction
    correspond to the magnitude of the rotation angle and the axis of
    rotation respectively.

    The conversion has a singularity around `log(R) = 0`. It is handled
    by clamping the squared norm of `log_rot` from below with `eps`
    before taking the square root.

    Args:
        log_rot: Batch of vectors of shape `(minibatch , 3)`.
        eps: A float constant handling the conversion singularity.

    Returns:
        Batch of rotation matrices of shape `(minibatch , 3 , 3)`.

    Raises:
        ValueError if `log_rot` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    """
    _, dim = log_rot.shape
    if dim != 3:
        raise ValueError("Input tensor shape has to be Nx3.")

    squared_norms = (log_rot * log_rot).sum(1)
    angles = torch.clamp(squared_norms, eps).sqrt()
    inv_angles = 1.0 / angles
    # Coefficients of the linear and quadratic terms in the skew matrix.
    linear_coeff = inv_angles * angles.sin()
    quadratic_coeff = inv_angles * inv_angles * (1.0 - angles.cos())
    skew = hat(log_rot)
    identity = torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]

    return (
        linear_coeff[:, None, None] * skew
        + quadratic_coeff[:, None, None] * torch.bmm(skew, skew)
        + identity
    )
|
||||
|
||||
|
||||
def so3_log_map(R, eps: float = 0.0001):
    """
    Convert a batch of 3x3 rotation matrices `R`
    to a batch of 3-dimensional matrix logarithms of rotation matrices.
    The conversion has a singularity around `(R=I)` which is handled
    by clamping the rotation angle from below with the `eps` argument.

    Args:
        R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: A float constant handling the conversion singularity.

    Returns:
        Batch of logarithms of input rotation matrices
        of shape `(minibatch, 3)`.

    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    _, dim1, dim2 = R.shape
    if dim1 != 3 or dim2 != 3:
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    phi = so3_rotation_angle(R)

    # sign() is 0 for an angle of exactly 0, which would undo the clamp
    # and give 0 / sin(0) = NaN below; treat zero as positive instead.
    phi_sign = phi.sign() + (phi == 0).type_as(phi)
    phi_valid = torch.clamp(phi.abs(), eps) * phi_sign

    # NOTE(review): angles close to pi also make sin() tiny; that case
    # is not handled here.
    log_rot_hat = (phi_valid / (2.0 * phi_valid.sin()))[:, None, None] * (
        R - R.permute(0, 2, 1)
    )
    return hat_inv(log_rot_hat)
|
||||
|
||||
|
||||
def hat_inv(h):
    """
    Compute the inverse Hat operator [1] of a batch of 3x3 matrices.

    Args:
        h: Batch of skew-symmetric matrices of shape `(minibatch, 3, 3)`.

    Returns:
        Batch of 3d vectors of shape `(minibatch, 3)`.

    Raises:
        ValueError if `h` is of incorrect shape.
        ValueError if `h` not skew-symmetric.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    N, dim1, dim2 = h.shape
    if dim1 != 3 or dim2 != 3:
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    # max() over an empty tensor raises, so an empty batch skips the
    # check and yields an empty (0, 3) result.
    if N > 0:
        # For a skew-symmetric matrix h + h^T is zero; allow a small
        # absolute deviation.
        ss_diff = (h + h.permute(0, 2, 1)).abs().max()
        if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
            raise ValueError("One of input matrices not skew-symmetric.")

    x = h[:, 2, 1]
    y = h[:, 0, 2]
    z = h[:, 1, 0]

    return torch.stack((x, y, z), dim=1)
|
||||
|
||||
|
||||
def hat(v):
    """
    Apply the Hat operator [1] to a batch of 3D vectors.

    Args:
        v: Batch of vectors of shape `(minibatch , 3)`.

    Returns:
        Batch of skew-symmetric matrices of shape
        `(minibatch, 3 , 3)` where each matrix is of the form:
            `[    0  -v_z   v_y ]
             [  v_z     0  -v_x ]
             [ -v_y   v_x     0 ]`

    Raises:
        ValueError if `v` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    batch, dim = v.shape
    if dim != 3:
        raise ValueError("Input vectors have to be 3-dimensional.")

    x, y, z = v.unbind(1)
    zero = v.new_zeros(batch)

    # Entries listed row by row, then folded into (batch, 3, 3).
    entries = (zero, -z, y, z, zero, -x, -y, x, zero)
    return torch.stack(entries, dim=1).reshape(batch, 3, 3)
|
||||
677
pytorch3d/transforms/transform3d.py
Normal file
677
pytorch3d/transforms/transform3d.py
Normal file
@@ -0,0 +1,677 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class Transform3d:
|
||||
"""
|
||||
A Transform3d object encapsulates a batch of N 3D transformations, and knows
|
||||
how to transform points and normal vectors. Suppose that t is a Transform3d;
|
||||
then we can do the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
N = len(t)
|
||||
points = torch.randn(N, P, 3)
|
||||
normals = torch.randn(N, P, 3)
|
||||
points_transformed = t.transform_points(points) # => (N, P, 3)
|
||||
normals_transformed = t.transform_normals(normals) # => (N, P, 3)
|
||||
|
||||
|
||||
BROADCASTING
|
||||
Transform3d objects supports broadcasting. Suppose that t1 and tN are
|
||||
Transform3D objects with len(t1) == 1 and len(tN) == N respectively. Then we
|
||||
can broadcast transforms like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
t1.transform_points(torch.randn(P, 3)) # => (P, 3)
|
||||
t1.transform_points(torch.randn(1, P, 3)) # => (1, P, 3)
|
||||
t1.transform_points(torch.randn(M, P, 3)) # => (M, P, 3)
|
||||
tN.transform_points(torch.randn(P, 3)) # => (N, P, 3)
|
||||
tN.transform_points(torch.randn(1, P, 3)) # => (N, P, 3)
|
||||
|
||||
|
||||
COMBINING TRANSFORMS
|
||||
Transform3d objects can be combined in two ways: composing and stacking.
|
||||
Composing is function composition. Given Transform3d objects t1, t2, t3,
|
||||
the following all compute the same thing:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
|
||||
y2 = t1.compose(t2).compose(t3).transform_points(x)
|
||||
y3 = t1.compose(t2, t3).transform_points(x)
|
||||
|
||||
|
||||
Composing transforms should broadcast.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
if len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N.
|
||||
|
||||
We can also stack a sequence of Transform3d objects, which represents
|
||||
composition along the batch dimension; then the following should compute the
|
||||
same thing.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
N, M = len(tN), len(tM)
|
||||
xN = torch.randn(N, P, 3)
|
||||
xM = torch.randn(M, P, 3)
|
||||
y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0)
|
||||
y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0))
|
||||
|
||||
BUILDING TRANSFORMS
|
||||
We provide convenience methods for easily building Transform3d objects
|
||||
as compositions of basic transforms.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Scale by 0.5, then translate by (1, 2, 3)
|
||||
t1 = Transform3d().scale(0.5).translate(1, 2, 3)
|
||||
|
||||
# Scale each axis by a different amount, then translate, then scale
|
||||
t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0)
|
||||
|
||||
t3 = t1.compose(t2)
|
||||
tN = t1.stack(t3, t3)
|
||||
|
||||
|
||||
BACKPROP THROUGH TRANSFORMS
|
||||
When building transforms, we can also parameterize them by Torch tensors;
|
||||
in this case we can backprop through the construction and application of
|
||||
Transform objects, so they could be learned via gradient descent or
|
||||
predicted by a neural network.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
s1_params = torch.randn(N, requires_grad=True)
|
||||
t_params = torch.randn(N, 3, requires_grad=True)
|
||||
s2_params = torch.randn(N, 3, requires_grad=True)
|
||||
|
||||
t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params)
|
||||
x = torch.randn(N, 3)
|
||||
y = t.transform_points(x)
|
||||
loss = compute_loss(y)
|
||||
loss.backward()
|
||||
|
||||
with torch.no_grad():
|
||||
s1_params -= lr * s1_params.grad
|
||||
t_params -= lr * t_params.grad
|
||||
s2_params -= lr * s2_params.grad
|
||||
"""
|
||||
|
||||
def __init__(self, dtype=torch.float32, device="cpu"):
    """
    Create an identity transform with a batch size of one.

    This class assumes a row major ordering for all matrices.

    Args:
        dtype: dtype of the stored 4x4 matrix.
        device: device of the stored matrix; also kept on the instance.
    """
    identity = torch.eye(4, dtype=dtype, device=device)
    self._matrix = identity.view(1, 4, 4)
    # Transforms to compose with this one when the matrix is requested.
    self._transforms = []
    self._lu = None
    self.device = device
|
||||
|
||||
def __len__(self):
    """Return the batch size, i.e. the first dimension of the composed matrix."""
    composed = self.get_matrix()
    return composed.shape[0]
|
||||
|
||||
def compose(self, *others):
    """
    Return a new Transform3d that applies this transform followed by
    the given ones. Nothing is multiplied here: the others are only
    appended to the internal list of stored transforms.

    Args:
        *others: Any number of Transform3d objects

    Returns:
        A new Transform3d with the stored transforms

    Raises:
        ValueError: if any argument is not a Transform3d.
    """
    composed = Transform3d(device=self.device)
    composed._matrix = self._matrix.clone()
    for candidate in others:
        if isinstance(candidate, Transform3d):
            continue
        msg = "Only possible to compose Transform3d objects; got %s"
        raise ValueError(msg % type(candidate))
    # A new list is built, so the list stored on self is not modified.
    composed._transforms = self._transforms + list(others)
    return composed
|
||||
|
||||
def get_matrix(self):
    """
    Return the matrix obtained by composing this transform with the
    transforms stored in self._transforms.

    Starting from a copy of self._matrix, the matrix of every stored
    transform is multiplied on the right, in stored order, with
    _broadcast_bmm, so batch dimensions are broadcast where necessary.

    Returns:
        A transformation matrix representing the composed inputs.
    """
    result = self._matrix.clone()
    for stored in self._transforms:
        result = _broadcast_bmm(result, stored.get_matrix())
    return result
|
||||
|
||||
def _get_matrix_inverse(self):
    """
    Return the inverse of self._matrix (the stored transforms in
    self._transforms are not involved).
    """
    return self._matrix.inverse()
|
||||
|
||||
def inverse(self, invert_composed: bool = False):
    """
    Return a new Transform3d that represents the inverse of the current
    transformation.

    Args:
        invert_composed:
            - True: First compose the list of stored transformations
              and then apply inverse to the result. This is
              potentially slower for classes of transformations
              with inverses that can be computed efficiently
              (e.g. rotations and translations).
            - False: Invert the individual stored transformations
              independently without composing them.

    Returns:
        A new Transform3d object containing the inverse of the original
        transformation.
    """
    tinv = Transform3d(device=self.device)

    if invert_composed:
        # Compose everything first, then invert the single result.
        tinv._matrix = torch.inverse(self.get_matrix())
        return tinv

    # _get_matrix_inverse() inverts self._matrix only.
    own_inverse = self._get_matrix_inverse()

    if not self._transforms:
        # Nothing stored: the inverted matrix is the whole answer.
        tinv._matrix = own_inverse
        return tinv

    # Stored transforms are inverted one by one in reverse order; the
    # inverse of self._matrix goes last, so that get_matrix()
    # right-multiplies by it at the end of the composition.
    inverted = [t.inverse() for t in reversed(self._transforms)]
    own_inverse_transform = Transform3d(device=self.device)
    own_inverse_transform._matrix = own_inverse
    inverted.append(own_inverse_transform)
    tinv._transforms = inverted
    return tinv
|
||||
|
||||
def stack(self, *others):
    """
    Return a new batched Transform3d that concatenates this transform
    and the given ones along the batch dimension.

    Each transform contributes its fully composed matrix from
    get_matrix(), so transforms built with compose() / scale() /
    translate() keep their stored steps. Previously only the raw
    _matrix was concatenated, which dropped those steps.

    Args:
        *others: Any number of Transform3d objects

    Returns:
        A new Transform3d whose matrix has the concatenated batch and
        no stored transforms.
    """
    transforms = [self] + list(others)
    matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
    # Keep the device attribute consistent with the stacked matrices.
    out = Transform3d(device=self.device)
    out._matrix = matrix
    return out
|
||||
|
||||
def transform_points(self, points, eps: float = None):
    """
    Use this transform to transform a set of 3D points. Assumes row major
    ordering of the input points.

    Args:
        points: Tensor of shape (P, 3) or (N, P, 3)
        eps: If eps!=None, the argument is used to clamp the
            last coordinate before performing the final division.
            The clamping corresponds to:
            last_coord := (last_coord.sign() + (last_coord==0)) *
            torch.clamp(last_coord.abs(), eps),
            i.e. the last coordinates that are exactly 0 will
            be clamped to +eps.

    Returns:
        points_out: points of shape (N, P, 3) or (P, 3) depending
        on the dimensions of the transform

    Raises:
        ValueError: if points is neither 2- nor 3-dimensional.
    """
    points_batch = points.clone()
    if points_batch.dim() == 2:
        points_batch = points_batch[None]  # (P, 3) -> (1, P, 3)
    if points_batch.dim() != 3:
        msg = "Expected points to have dim = 2 or dim = 3: got shape %r"
        # The shape is a tuple, so it is wrapped to be formatted as one
        # value; otherwise % would spread its entries over the format.
        raise ValueError(msg % (points.shape,))

    N, P, _3 = points_batch.shape
    # Append a homogeneous coordinate of 1 to every point.
    ones = torch.ones(N, P, 1, dtype=points.dtype, device=points.device)
    points_batch = torch.cat([points_batch, ones], dim=2)

    composed_matrix = self.get_matrix()
    points_out = _broadcast_bmm(points_batch, composed_matrix)
    denom = points_out[..., 3:]  # denominator
    if eps is not None:
        denom_sign = denom.sign() + (denom == 0.0).type_as(denom)
        denom = denom_sign * torch.clamp(denom.abs(), eps)
    points_out = points_out[..., :3] / denom

    # When transform is (1, 4, 4) and points is (P, 3) return
    # points_out of shape (P, 3)
    if points_out.shape[0] == 1 and points.dim() == 2:
        points_out = points_out.reshape(points.shape)

    return points_out
|
||||
|
||||
def transform_normals(self, normals):
    """
    Use this transform to transform a set of normal vectors.

    Args:
        normals: Tensor of shape (P, 3) or (N, P, 3)

    Returns:
        normals_out: Tensor of shape (P, 3) or (N, P, 3) depending
        on the dimensions of the transform

    Raises:
        ValueError: if normals is neither 2- nor 3-dimensional.
    """
    if normals.dim() not in [2, 3]:
        msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
        # torch.Size is a tuple subclass, so it must be wrapped in a 1-tuple;
        # otherwise "%" unpacks it as the list of format arguments.
        raise ValueError(msg % (normals.shape,))
    composed_matrix = self.get_matrix()

    # Normals are multiplied by the inverse of the transposed upper-left
    # 3x3 block of the composed matrix.
    # TODO: inverse is bad! Solve a linear system instead
    mat = composed_matrix[:, :3, :3]
    normals_out = _broadcast_bmm(normals, mat.transpose(1, 2).inverse())

    # TODO: an LU-based variant (caching
    # `self._matrix[:, :3, :3].transpose(1, 2).lu()` in `self._lu` and
    # calling `normals.lu_solve(*self._lu)`) does not pass unit tests;
    # investigate further.

    # When transform is (1, 4, 4) and normals is (P, 3) return
    # normals_out of shape (P, 3)
    if normals_out.shape[0] == 1 and normals.dim() == 2:
        normals_out = normals_out.reshape(normals.shape)

    return normals_out
|
||||
|
||||
def translate(self, *args, **kwargs):
    """
    Build a Translate from the given arguments on this transform's device
    and compose it with this transform.

    Returns:
        The result of self.compose() applied to the new Translate.
    """
    translation = Translate(*args, device=self.device, **kwargs)
    return self.compose(translation)
|
||||
|
||||
def scale(self, *args, **kwargs):
    """
    Build a Scale from the given arguments on this transform's device
    and compose it with this transform.

    Returns:
        The result of self.compose() applied to the new Scale.
    """
    scaling = Scale(*args, device=self.device, **kwargs)
    return self.compose(scaling)
|
||||
|
||||
def rotate_axis_angle(self, *args, **kwargs):
    """
    Build a RotateAxisAngle from the given arguments on this transform's
    device and compose it with this transform.

    Returns:
        The result of self.compose() applied to the new RotateAxisAngle.
    """
    rotation = RotateAxisAngle(*args, device=self.device, **kwargs)
    return self.compose(rotation)
|
||||
|
||||
def clone(self):
    """
    Deep copy of Transforms object. The matrix, every composed
    sub-transform and (when present) the cached `_lu` tensors are each
    cloned individually.

    Returns:
        new Transforms object.
    """
    duplicate = Transform3d(device=self.device)
    duplicate._matrix = self._matrix.clone()
    duplicate._transforms = [t.clone() for t in self._transforms]
    if self._lu is not None:
        duplicate._lu = [factor.clone() for factor in self._lu]
    return duplicate
|
||||
|
||||
def to(self, device, copy: bool = False, dtype=None):
    """
    Match functionality of torch.Tensor.to()
    If copy = True, or self is on a different device, or a dtype different
    from the one of the stored matrix is requested, the returned transform
    is a converted copy of self.
    Otherwise (copy = False and nothing has to change) self is returned.

    Args:
        device: Device id for the new tensor.
        copy: Boolean indicator whether or not to clone self. Default False.
        dtype: If not None, casts the internal tensor variables
            to a given torch.dtype.

    Returns:
        Transform3d object.
    """
    needs_move = self.device != device
    # A dtype request must not be dropped just because the device matches.
    needs_cast = dtype is not None and self._matrix.dtype != dtype
    if not (copy or needs_move or needs_cast):
        return self

    other = self.clone()
    if needs_move or needs_cast:
        other.device = device
        other._matrix = other._matrix.to(device=device, dtype=dtype)
        # t.to() may return a new object rather than converting in place,
        # so the converted sub-transforms have to be stored back.
        other._transforms = [
            t.to(device, copy=copy, dtype=dtype) for t in other._transforms
        ]
        if other._lu is not None:
            # Keep the cached tensors on the same device as the matrix.
            other._lu = [factor.to(device=device) for factor in other._lu]
    return other
|
||||
|
||||
def cpu(self):
    """Return self.to() called with the torch "cpu" device."""
    cpu_device = torch.device("cpu")
    return self.to(cpu_device)
|
||||
|
||||
def cuda(self):
    """Return self.to() called with the torch "cuda" device."""
    cuda_device = torch.device("cuda")
    return self.to(cuda_device)
|
||||
|
||||
|
||||
class Translate(Transform3d):
    def __init__(
        self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
    ):
        """
        Create a new Transform3d representing 3D translations.

        Option I: Translate(xyz, dtype=torch.float32, device='cpu')
            xyz should be a tensor of shape (N, 3)

        Option II: Translate(x, y, z, dtype=torch.float32, device='cpu')
            Here x, y, and z will be broadcast against each other and
            concatenated to form the translation. Each can be:
                - A python scalar
                - A torch scalar
                - A 1D torch tensor
        """
        super().__init__(device=device)
        offsets = _handle_input(x, y, z, dtype, device, "Translate")
        batch_size = offsets.shape[0]

        # One identity matrix per batch element; the offsets go into the
        # first three entries of the last row.
        identity = torch.eye(4, dtype=dtype, device=device)
        mat = identity[None].repeat(batch_size, 1, 1)
        mat[:, 3, :3] = offsets
        self._matrix = mat

    def _get_matrix_inverse(self):
        """
        Return the inverse of self._matrix.
        """
        # Negating the translation entries (and nothing else) inverts the
        # matrix built in __init__.
        sign_flip = self._matrix.new_ones([1, 4, 4])
        sign_flip[0, 3, :3] = -1.0
        return self._matrix * sign_flip
|
||||
|
||||
|
||||
class Scale(Transform3d):
    def __init__(
        self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
    ):
        """
        A Transform3d representing a scaling operation, with different scale
        factors along each coordinate axis.

        Option I: Scale(s, dtype=torch.float32, device='cpu')
            s can be one of
                - Python scalar or torch scalar: Single uniform scale
                - 1D torch tensor of shape (N,): A batch of uniform scale
                - 2D torch tensor of shape (N, 3): Scale differently along each axis

        Option II: Scale(x, y, z, dtype=torch.float32, device='cpu')
            Each of x, y, and z can be one of
                - python scalar
                - torch scalar
                - 1D torch tensor
        """
        super().__init__(device=device)
        factors = _handle_input(
            x, y, z, dtype, device, "scale", allow_singleton=True
        )
        batch_size = factors.shape[0]

        # TODO: Can we do this all in one go somehow?
        identity = torch.eye(4, dtype=dtype, device=device)
        mat = identity[None].repeat(batch_size, 1, 1)
        # Write the per-axis factors onto the first three diagonal entries.
        for axis in range(3):
            mat[:, axis, axis] = factors[:, axis]
        self._matrix = mat

    def _get_matrix_inverse(self):
        """
        Return the inverse of self._matrix.
        """
        # The matrix is diagonal, so its inverse is the diagonal matrix of
        # the reciprocals of all four diagonal entries.
        diag = torch.diagonal(self._matrix, dim1=1, dim2=2)
        inv_diag = 1.0 / diag
        return torch.diag_embed(inv_diag, dim1=1, dim2=2)
|
||||
|
||||
|
||||
class Rotate(Transform3d):
    def __init__(
        self,
        R,
        dtype=torch.float32,
        device: str = "cpu",
        orthogonal_tol: float = 1e-5,
    ):
        """
        Create a new Transform3d representing 3D rotation using a rotation
        matrix as the input.

        Args:
            R: a tensor of shape (3, 3) or (N, 3, 3)
            dtype: dtype that R is converted to and that the stored matrix has
            device: device that R is moved to and that the stored matrix is on
            orthogonal_tol: tolerance for the test of the orthogonality of R

        Raises:
            ValueError: if the last two dimensions of R are not (3, 3).
        """
        super().__init__(device=device)
        rot = R[None] if R.dim() == 2 else R
        if rot.shape[-2:] != (3, 3):
            msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
            raise ValueError(msg % repr(rot.shape))
        rot = rot.to(dtype=dtype).to(device=device)
        _check_valid_rotation_matrix(rot, tol=orthogonal_tol)

        # Embed each 3x3 rotation into the upper-left block of a 4x4 identity.
        batch_size = rot.shape[0]
        identity = torch.eye(4, dtype=dtype, device=device)
        mat = identity[None].repeat(batch_size, 1, 1)
        mat[:, :3, :3] = rot
        self._matrix = mat

    def _get_matrix_inverse(self):
        """
        Return the inverse of self._matrix.
        """
        # The inverse is taken to be the per-batch-element transpose.
        return self._matrix.transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
class RotateAxisAngle(Rotate):
    def __init__(
        self,
        angle,
        axis: str = "X",
        degrees: bool = True,
        dtype=torch.float64,
        device: str = "cpu",
    ):
        """
        Create a new Transform3d representing 3D rotation about an axis
        by an angle.

        Args:
            angle:
                - A torch tensor of shape (N, 1)
                - A python scalar
                - A torch scalar
            axis:
                string: one of ["X", "Y", "Z"] indicating the axis about which
                to rotate (case-insensitive).
                NOTE: All batch elements are rotated about the same axis.
            degrees: if True the angle is interpreted in degrees and
                converted to radians, otherwise it is used as given.
            dtype: dtype used when converting a non-tensor angle to a tensor.
            device: device used for the angle and passed on to Rotate.

        Raises:
            ValueError: if axis is not one of "X", "Y", "Z".
        """
        axis = axis.upper()
        if axis not in ["X", "Y", "Z"]:
            msg = "Expected axis to be one of ['X', 'Y', 'Z']; got %s"
            raise ValueError(msg % axis)

        angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
        if degrees:
            angle = angle / 180.0 * math.pi
        batch_size = angle.shape[0]

        cos = torch.cos(angle)
        sin = torch.sin(angle)
        one = torch.ones_like(angle)
        zero = torch.zeros_like(angle)

        # Row-major entries of the 3x3 rotation matrix for each axis.
        entries_by_axis = {
            "X": (one, zero, zero, zero, cos, -sin, zero, sin, cos),
            "Y": (cos, zero, sin, zero, one, zero, -sin, zero, cos),
            "Z": (cos, -sin, zero, sin, cos, zero, zero, zero, one),
        }
        R_flat = entries_by_axis[axis]

        R = torch.stack(R_flat, -1).reshape((batch_size, 3, 3))
        # NOTE: dtype is not forwarded here, so Rotate converts R to its own
        # default dtype.
        super().__init__(device=device, R=R)
|
||||
|
||||
|
||||
def _handle_coord(c, dtype, device):
|
||||
"""
|
||||
Helper function for _handle_input.
|
||||
|
||||
Args:
|
||||
c: Python scalar, torch scalar, or 1D torch tensor
|
||||
|
||||
Returns:
|
||||
c_vec: 1D torch tensor
|
||||
"""
|
||||
if not torch.is_tensor(c):
|
||||
c = torch.tensor(c, dtype=dtype, device=device)
|
||||
if c.dim() == 0:
|
||||
c = c.view(1)
|
||||
return c
|
||||
|
||||
|
||||
def _handle_input(
|
||||
x, y, z, dtype, device, name: str, allow_singleton: bool = False
|
||||
):
|
||||
"""
|
||||
Helper function to handle parsing logic for building transforms. The output
|
||||
is always a tensor of shape (N, 3), but there are several types of allowed
|
||||
input.
|
||||
|
||||
Case I: Single Matrix
|
||||
In this case x is a tensor of shape (N, 3), and y and z are None. Here just
|
||||
return x.
|
||||
|
||||
Case II: Vectors and Scalars
|
||||
In this case each of x, y, and z can be one of the following
|
||||
- Python scalar
|
||||
- Torch scalar
|
||||
- Torch tensor of shape (N, 1) or (1, 1)
|
||||
In this case x, y and z are broadcast to tensors of shape (N, 1)
|
||||
and concatenated to a tensor of shape (N, 3)
|
||||
|
||||
Case III: Singleton (only if allow_singleton=True)
|
||||
In this case y and z are None, and x can be one of the following:
|
||||
- Python scalar
|
||||
- Torch scalar
|
||||
- Torch tensor of shape (N, 1) or (1, 1)
|
||||
Here x will be duplicated 3 times, and we return a tensor of shape (N, 3)
|
||||
|
||||
Returns:
|
||||
xyz: Tensor of shape (N, 3)
|
||||
"""
|
||||
# If x is actually a tensor of shape (N, 3) then just return it
|
||||
if torch.is_tensor(x) and x.dim() == 2:
|
||||
if x.shape[1] != 3:
|
||||
msg = "Expected tensor of shape (N, 3); got %r (in %s)"
|
||||
raise ValueError(msg % (x.shape, name))
|
||||
if y is not None or z is not None:
|
||||
msg = "Expected y and z to be None (in %s)" % name
|
||||
raise ValueError(msg)
|
||||
return x
|
||||
|
||||
if allow_singleton and y is None and z is None:
|
||||
y = x
|
||||
z = x
|
||||
|
||||
# Convert all to 1D tensors
|
||||
xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
|
||||
|
||||
# Broadcast and concatenate
|
||||
sizes = [c.shape[0] for c in xyz]
|
||||
N = max(sizes)
|
||||
for c in xyz:
|
||||
if c.shape[0] != 1 and c.shape[0] != N:
|
||||
msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name)
|
||||
raise ValueError(msg)
|
||||
xyz = [c.expand(N) for c in xyz]
|
||||
xyz = torch.stack(xyz, dim=1)
|
||||
return xyz
|
||||
|
||||
|
||||
def _handle_angle_input(x, dtype, device: str, name: str):
|
||||
"""
|
||||
Helper function for building a rotation function using angles.
|
||||
The output is always of shape (N, 1).
|
||||
|
||||
The input can be one of:
|
||||
- Torch tensor (N, 1) or (N)
|
||||
- Python scalar
|
||||
- Torch scalar
|
||||
"""
|
||||
# If x is actually a tensor of shape (N, 1) then just return it
|
||||
if torch.is_tensor(x) and x.dim() == 2:
|
||||
if x.shape[1] != 1:
|
||||
msg = "Expected tensor of shape (N, 1); got %r (in %s)"
|
||||
raise ValueError(msg % (x.shape, name))
|
||||
return x
|
||||
else:
|
||||
return _handle_coord(x, dtype, device)
|
||||
|
||||
|
||||
def _broadcast_bmm(a, b):
|
||||
"""
|
||||
Batch multiply two matrices and broadcast if necessary.
|
||||
|
||||
Args:
|
||||
a: torch tensor of shape (P, K) or (M, P, K)
|
||||
b: torch tensor of shape (N, K, K)
|
||||
|
||||
Returns:
|
||||
a and b broadcast multipled. The output batch dimension is max(N, M).
|
||||
|
||||
To broadcast transforms across a batch dimension if M != N then
|
||||
expect that either M = 1 or N = 1. The tensor with batch dimension 1 is
|
||||
expanded to have shape N or M.
|
||||
"""
|
||||
if a.dim() == 2:
|
||||
a = a[None]
|
||||
if len(a) != len(b):
|
||||
if not ((len(a) == 1) or (len(b) == 1)):
|
||||
msg = "Expected batch dim for bmm to be equal or 1; got %r, %r"
|
||||
raise ValueError(msg % (a.shape, b.shape))
|
||||
if len(a) == 1:
|
||||
a = a.expand(len(b), -1, -1)
|
||||
if len(b) == 1:
|
||||
b = b.expand(len(a), -1, -1)
|
||||
return a.bmm(b)
|
||||
|
||||
|
||||
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
|
||||
"""
|
||||
Determine if R is a valid rotation matrix by checking it satisfies the
|
||||
following conditions:
|
||||
|
||||
``RR^T = I and det(R) = 1``
|
||||
|
||||
Args:
|
||||
R: an (N, 3, 3) matrix
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Prints an warning if R is an invalid rotation matrix. Else return.
|
||||
"""
|
||||
N = R.shape[0]
|
||||
eye = torch.eye(3, dtype=R.dtype, device=R.device)
|
||||
eye = eye.view(1, 3, 3).expand(N, -1, -1)
|
||||
orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
|
||||
det_R = torch.det(R)
|
||||
no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
|
||||
if not (orthogonal and no_distortion):
|
||||
msg = "R is not a valid rotation matrix"
|
||||
print(msg)
|
||||
return
|
||||
Reference in New Issue
Block a user