Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
facebook-github-bot committed on 2020-01-23 11:53:41 -08:00
commit dbf06b504b
211 changed files with 47362 additions and 0 deletions

pytorch3d/transforms/__init__.py (new file)

@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .rotation_conversions import (
euler_angles_to_matrix,
matrix_to_euler_angles,
matrix_to_quaternion,
quaternion_apply,
quaternion_invert,
quaternion_multiply,
quaternion_raw_multiply,
quaternion_to_matrix,
random_quaternions,
random_rotation,
random_rotations,
standardize_quaternion,
)
from .so3 import (
so3_exponential_map,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,
)
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/transforms/rotation_conversions.py (new file)

@@ -0,0 +1,374 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import functools
import torch
def quaternion_to_matrix(quaternions):
"""
Convert rotations given as quaternions to rotation matrices.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
r, i, j, k = torch.unbind(quaternions, -1)
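# two_s = 2 / |q|^2: with this factor the standard formula below also handles non-unit quaternions.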
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def _copysign(a, b):
"""
Return a tensor where each element has the absolute value taken from the
corresponding element of a, with sign taken from the corresponding
element of b. This is like the standard copysign floating-point operation,
but is not careful about negative 0 and NaN.
Args:
a: source tensor.
b: tensor whose signs will be used, of the same shape as a.
Returns:
Tensor of the same shape as a with the signs of b.
"""
signs_differ = (a < 0) != (b < 0)
return torch.where(signs_differ, -a, a)
def matrix_to_quaternion(matrix):
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
zero = matrix.new_zeros((1,))
m00 = matrix[..., 0, 0]
m11 = matrix[..., 1, 1]
m22 = matrix[..., 2, 2]
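# Candidate magnitudes |w|, |x|, |y|, |z| from the diagonal (clamped at 0 before sqrt);
# the signs of x, y, z are recovered below from off-diagonal differences via _copysign.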
o0 = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 + m11 + m22))
x = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 - m11 - m22))
y = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 + m11 - m22))
z = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 - m11 + m22))
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
return torch.stack((o0, o1, o2, o3), -1)
def _primary_matrix(axis: str, angle):
"""
Return the rotation matrices for a rotation about a single coordinate axis
(one of the axes an Euler angle convention refers to), for each value of the angle given.
Args:
axis: Axis label "X" or "Y or "Z".
angle: any shape tensor of Euler angles in radians
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
o = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
o = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
o = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
return torch.stack(o, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str):
"""
Convert rotations given as Euler angles in radians to rotation matrices.
Args:
euler_angles: Euler angles in radians as tensor of shape (..., 3).
convention: Convention string of three uppercase letters from
{"X", "Y", and "Z"}.
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
raise ValueError("Invalid input euler angles.")
if len(convention) != 3:
raise ValueError("Convention must have 3 letters.")
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(_primary_matrix, convention, torch.unbind(euler_angles, -1))
return functools.reduce(torch.matmul, matrices)
def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
):
"""
Extract the first or third Euler angle from the two members of
the matrix which are a positive constant times its sine and cosine.
Args:
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
convention.
data: Rotation matrices as tensor of shape (..., 3, 3).
horizontal: Whether we are looking for the angle for the third axis,
which means the relevant entries are in the same row of the
rotation matrix. If not, they are in the same column.
tait_bryan: Whether the first and third axes in the convention differ.
Returns:
Euler angles in radians for each matrix in data as a tensor
of shape (...).
"""
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])
def _index_from_letter(letter: str):
if letter == "X":
return 0
if letter == "Y":
return 1
if letter == "Z":
return 2
def matrix_to_euler_angles(matrix, convention: str):
"""
Convert rotations given as rotation matrices to Euler angles in radians.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
convention: Convention string of three uppercase letters.
Returns:
Euler angles in radians as tensor of shape (..., 3).
"""
if len(convention) != 3:
raise ValueError("Convention must have 3 letters.")
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
i0 = _index_from_letter(convention[0])
i2 = _index_from_letter(convention[2])
tait_bryan = i0 != i2
if tait_bryan:
central_angle = torch.asin(
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
)
else:
central_angle = torch.acos(matrix[..., i0, i0])
o = (
_angle_from_tan(
convention[0], convention[1], matrix[..., i2], False, tait_bryan
),
central_angle,
_angle_from_tan(
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
),
)
return torch.stack(o, -1)
def random_quaternions(
n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
"""
Generate random quaternions representing rotations,
i.e. versors with nonnegative real part.
Args:
n: Number to return.
dtype: Type to return.
device: Desired device of returned tensor. Default:
uses the current device for the default tensor type.
requires_grad: Whether the resulting tensor should have the gradient
flag set.
Returns:
Quaternions as tensor of shape (N, 4).
"""
o = torch.randn(
(n, 4), dtype=dtype, device=device, requires_grad=requires_grad
)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
return o
def random_rotations(
n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
"""
Generate random rotations as 3x3 rotation matrices.
Args:
n: Number to return.
dtype: Type to return.
device: Device of returned tensor. Default: if None,
uses the current device for the default tensor type.
requires_grad: Whether the resulting tensor should have the gradient
flag set.
Returns:
Rotation matrices as tensor of shape (n, 3, 3).
"""
quaternions = random_quaternions(
n, dtype=dtype, device=device, requires_grad=requires_grad
)
return quaternion_to_matrix(quaternions)
def random_rotation(
dtype: torch.dtype = None, device=None, requires_grad=False
):
"""
Generate a single random 3x3 rotation matrix.
Args:
dtype: Type to return
device: Device of returned tensor. Default: if None,
uses the current device for the default tensor type
requires_grad: Whether the resulting tensor should have the gradient
flag set
Returns:
Rotation matrix as tensor of shape (3, 3).
"""
return random_rotations(1, dtype, device, requires_grad)[0]
def standardize_quaternion(quaternions):
"""
Convert a unit quaternion to a standard form: one in which the real
part is non-negative.
Args:
quaternions: Quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
def quaternion_raw_multiply(a, b):
"""
Multiply two quaternions.
Usual torch rules for broadcasting apply.
Args:
a: Quaternions as tensor of shape (..., 4), real part first.
b: Quaternions as tensor of shape (..., 4), real part first.
Returns:
The product of a and b, a tensor of quaternions of shape (..., 4).
"""
aw, ax, ay, az = torch.unbind(a, -1)
bw, bx, by, bz = torch.unbind(b, -1)
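# Hamilton product, with the real part first in both inputs and the output.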
ow = aw * bw - ax * bx - ay * by - az * bz
ox = aw * bx + ax * bw + ay * bz - az * by
oy = aw * by - ax * bz + ay * bw + az * bx
oz = aw * bz + ax * by - ay * bx + az * bw
return torch.stack((ow, ox, oy, oz), -1)
def quaternion_multiply(a, b):
"""
Multiply two quaternions representing rotations, returning the quaternion
representing their composition, i.e. the versor with nonnegative real part.
Usual torch rules for broadcasting apply.
Args:
a: Quaternions as tensor of shape (..., 4), real part first.
b: Quaternions as tensor of shape (..., 4), real part first.
Returns:
The product of a and b, a tensor of quaternions of shape (..., 4).
"""
ab = quaternion_raw_multiply(a, b)
return standardize_quaternion(ab)
def quaternion_invert(quaternion):
"""
Given a quaternion representing rotation, get the quaternion representing
its inverse.
Args:
quaternion: Quaternions as tensor of shape (..., 4), with real part
first, which must be versors (unit quaternions).
Returns:
The inverse, a tensor of quaternions of shape (..., 4).
"""
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
def quaternion_apply(quaternion, point):
"""
Apply the rotation given by a quaternion to a 3D point.
Usual torch rules for broadcasting apply.
Args:
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
point: Tensor of 3D points of shape (..., 3).
Returns:
Tensor of rotated points of shape (..., 3).
"""
if point.size(-1) != 3:
raise ValueError(f"Points are not in 3D, f{point.shape}.")
real_parts = point.new_zeros(point.shape[:-1] + (1,))
point_as_quaternion = torch.cat((real_parts, point), -1)
out = quaternion_raw_multiply(
quaternion_raw_multiply(quaternion, point_as_quaternion),
quaternion_invert(quaternion),
)
return out[..., 1:]
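
A minimal round-trip sketch for the conversions above (assuming the functions defined in this file are in scope; inputs are arbitrary random rotations):

import torch
R = random_rotations(4)                      # (4, 3, 3)
q = matrix_to_quaternion(R)                  # (4, 4), real part first
assert torch.allclose(quaternion_to_matrix(q), R, atol=1e-5)
angles = matrix_to_euler_angles(R, "XYZ")    # (4, 3), radians
assert torch.allclose(euler_angles_to_matrix(angles, "XYZ"), R, atol=1e-5)
# Applying a quaternion to points matches multiplying by its rotation matrix.
p = torch.randn(4, 3)
assert torch.allclose(
    quaternion_apply(q, p), torch.einsum("nij,nj->ni", R, p), atol=1e-5
)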

pytorch3d/transforms/so3.py (new file)

@@ -0,0 +1,236 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
def so3_relative_angle(R1, R2, cos_angle: bool = False):
"""
Calculates the relative angle (in radians) between pairs of
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T) - 1))`.
.. note::
This corresponds to a geodesic distance on the 3D manifold of rotation
matrices.
Args:
R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
cos_angle: If==True return cosine of the relative angle rather than
the angle itself. This can avoid the unstable
calculation of `acos`.
Returns:
Corresponding rotation angles of shape `(minibatch,)`.
If `cos_angle==True`, returns the cosine of the angles.
Raises:
ValueError if `R1` or `R2` is of incorrect shape.
ValueError if `R1` or `R2` has an unexpected trace.
"""
R12 = torch.bmm(R1, R2.permute(0, 2, 1))
return so3_rotation_angle(R12, cos_angle=cos_angle)
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
"""
Calculates angles (in radians) of a batch of rotation matrices `R` with
`angle = acos(0.5 * (Trace(R)-1))`. The trace of the
input matrices is checked to be in the valid range `[-1-eps,3+eps]`.
The `eps` argument is a small constant that allows for small errors
caused by limited machine precision.
Args:
R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: Tolerance for the valid trace check.
cos_angle: If==True return cosine of the rotation angles rather than
the angle itself. This can avoid the unstable
calculation of `acos`.
Returns:
Corresponding rotation angles of shape `(minibatch,)`.
If `cos_angle==True`, returns the cosine of the angles.
Raises:
ValueError if `R` is of incorrect shape.
ValueError if `R` has an unexpected trace.
"""
N, dim1, dim2 = R.shape
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
raise ValueError(
"A matrix has trace outside valid range [-1-eps,3+eps]."
)
# clamp to valid range
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)
# phi ... rotation angle
phi = 0.5 * (rot_trace - 1.0)
if cos_angle:
return phi
else:
return phi.acos()
def so3_exponential_map(log_rot, eps: float = 0.0001):
"""
Convert a batch of logarithmic representations of rotation matrices `log_rot`
to a batch of 3x3 rotation matrices using Rodrigues formula [1].
In the logarithmic representation, each rotation matrix is represented as
a 3-dimensional vector (`log_rot`) whose l2-norm and direction correspond
to the magnitude of the rotation angle and the axis of rotation respectively.
The conversion has a singularity around `log(R) = 0`
which is handled by clamping controlled with the `eps` argument.
Args:
log_rot: Batch of vectors of shape `(minibatch , 3)`.
eps: A float constant handling the conversion singularity.
Returns:
Batch of rotation matrices of shape `(minibatch , 3 , 3)`.
Raises:
ValueError if `log_rot` is of incorrect shape.
[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
"""
_, dim = log_rot.shape
if dim != 3:
raise ValueError("Input tensor shape has to be Nx3.")
nrms = (log_rot * log_rot).sum(1)
# phis ... rotation angles
rot_angles = torch.clamp(nrms, eps).sqrt()
rot_angles_inv = 1.0 / rot_angles
fac1 = rot_angles_inv * rot_angles.sin()
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
skews = hat(log_rot)
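# Rodrigues' formula: R = I + (sin(phi) / phi) * K + ((1 - cos(phi)) / phi^2) * K^2, with K = hat(log_rot).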
R = (
fac1[:, None, None] * skews
+ fac2[:, None, None] * torch.bmm(skews, skews)
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
)
return R
def so3_log_map(R, eps: float = 0.0001):
"""
Convert a batch of 3x3 rotation matrices `R`
to a batch of 3-dimensional matrix logarithms of rotation matrices.
The conversion has a singularity around `(R=I)` which is handled
by clamping controlled with the `eps` argument.
Args:
R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
eps: A float constant handling the conversion singularity.
Returns:
Batch of logarithms of input rotation matrices
of shape `(minibatch, 3)`.
Raises:
ValueError if `R` is of incorrect shape.
ValueError if `R` has an unexpected trace.
"""
N, dim1, dim2 = R.shape
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
phi = so3_rotation_angle(R)
phi_valid = torch.clamp(phi.abs(), eps) * phi.sign()
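# For a rotation by angle phi about a unit axis u, R - R^T = 2 * sin(phi) * hat(u),
# so (phi / (2 * sin(phi))) * (R - R^T) = hat(phi * u), i.e. the hat of the log map.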
log_rot_hat = (phi_valid / (2.0 * phi_valid.sin()))[:, None, None] * (
R - R.permute(0, 2, 1)
)
log_rot = hat_inv(log_rot_hat)
return log_rot
def hat_inv(h):
"""
Compute the inverse Hat operator [1] of a batch of 3x3 matrices.
Args:
h: Batch of skew-symmetric matrices of shape `(minibatch, 3, 3)`.
Returns:
Batch of 3d vectors of shape `(minibatch, 3)`.
Raises:
ValueError if `h` is of incorrect shape.
ValueError if `h` not skew-symmetric.
[1] https://en.wikipedia.org/wiki/Hat_operator
"""
N, dim1, dim2 = h.shape
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
ss_diff = (h + h.permute(0, 2, 1)).abs().max()
if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
raise ValueError("One of input matrices not skew-symmetric.")
x = h[:, 2, 1]
y = h[:, 0, 2]
z = h[:, 1, 0]
v = torch.stack((x, y, z), dim=1)
return v
def hat(v):
"""
Compute the Hat operator [1] of a batch of 3D vectors.
Args:
v: Batch of vectors of shape `(minibatch , 3)`.
Returns:
Batch of skew-symmetric matrices of shape
`(minibatch, 3 , 3)` where each matrix is of the form:
`[ 0 -v_z v_y ]
[ v_z 0 -v_x ]
[ -v_y v_x 0 ]`
Raises:
ValueError if `v` is of incorrect shape.
[1] https://en.wikipedia.org/wiki/Hat_operator
"""
N, dim = v.shape
if dim != 3:
raise ValueError("Input vectors have to be 3-dimensional.")
h = v.new_zeros(N, 3, 3)
x, y, z = v.unbind(1)
h[:, 0, 1] = -z
h[:, 0, 2] = y
h[:, 1, 0] = z
h[:, 1, 2] = -x
h[:, 2, 0] = -y
h[:, 2, 1] = x
return h
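
A minimal exp/log round-trip sketch for the SO(3) helpers above (assuming the functions defined in this file are in scope):

import torch
log_rot = torch.randn(4, 3) * 0.5                  # axis-angle vectors; the norm is the rotation angle
R = so3_exponential_map(log_rot)                   # (4, 3, 3)
assert torch.allclose(so3_log_map(R), log_rot, atol=1e-4)
# Relative angle between R and R composed with an extra 0.1 rad rotation about X.
dR = so3_exponential_map(torch.tensor([[0.1, 0.0, 0.0]]).expand(4, -1))
angles = so3_relative_angle(R, torch.bmm(R, dR))   # each entry is close to 0.1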

pytorch3d/transforms/transform3d.py (new file)

@@ -0,0 +1,677 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
import torch
class Transform3d:
"""
A Transform3d object encapsulates a batch of N 3D transformations, and knows
how to transform points and normal vectors. Suppose that t is a Transform3d;
then we can do the following:
.. code-block:: python
N = len(t)
points = torch.randn(N, P, 3)
normals = torch.randn(N, P, 3)
points_transformed = t.transform_points(points) # => (N, P, 3)
normals_transformed = t.transform_normals(normals) # => (N, P, 3)
BROADCASTING
Transform3d objects support broadcasting. Suppose that t1 and tN are
Transform3d objects with len(t1) == 1 and len(tN) == N respectively. Then we
can broadcast transforms like this:
.. code-block:: python
t1.transform_points(torch.randn(P, 3)) # => (P, 3)
t1.transform_points(torch.randn(1, P, 3)) # => (1, P, 3)
t1.transform_points(torch.randn(M, P, 3)) # => (M, P, 3)
tN.transform_points(torch.randn(P, 3)) # => (N, P, 3)
tN.transform_points(torch.randn(1, P, 3)) # => (N, P, 3)
COMBINING TRANSFORMS
Transform3d objects can be combined in two ways: composing and stacking.
Composing is function composition. Given Transform3d objects t1, t2, t3,
the following all compute the same thing:
.. code-block:: python
y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
y2 = t1.compose(t2).compose(t3).transform_points(x)
y3 = t1.compose(t2, t3).transform_points(x)
Composing transforms should broadcast.
.. code-block:: python
if len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N.
We can also stack a sequence of Transform3d objects, which concatenates
them along the batch dimension; then the following should compute the
same thing.
.. code-block:: python
N, M = len(tN), len(tM)
xN = torch.randn(N, P, 3)
xM = torch.randn(M, P, 3)
y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0)
y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0))
BUILDING TRANSFORMS
We provide convenience methods for easily building Transform3d objects
as compositions of basic transforms.
.. code-block:: python
# Scale by 0.5, then translate by (1, 2, 3)
t1 = Transform3d().scale(0.5).translate(1, 2, 3)
# Scale each axis by a different amount, then translate, then scale
t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0)
t3 = t1.compose(t2)
tN = t1.stack(t3, t3)
BACKPROP THROUGH TRANSFORMS
When building transforms, we can also parameterize them by Torch tensors;
in this case we can backprop through the construction and application of
Transform objects, so they could be learned via gradient descent or
predicted by a neural network.
.. code-block:: python
s1_params = torch.randn(N, requires_grad=True)
t_params = torch.randn(N, 3, requires_grad=True)
s2_params = torch.randn(N, 3, requires_grad=True)
t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params)
x = torch.randn(N, 3)
y = t.transform_points(x)
loss = compute_loss(y)
loss.backward()
with torch.no_grad():
s1_params -= lr * s1_params.grad
t_params -= lr * t_params.grad
s2_params -= lr * s2_params.grad
"""
def __init__(self, dtype=torch.float32, device="cpu"):
"""
This class assumes a row major ordering for all matrices.
"""
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
self._transforms = [] # store transforms to compose
self._lu = None
self.device = device
def __len__(self):
return self.get_matrix().shape[0]
def compose(self, *others):
"""
Return a new Transform3d with the transforms to compose stored as
an internal list.
Args:
*others: Any number of Transform3d objects
Returns:
A new Transform3d with the stored transforms
"""
out = Transform3d(device=self.device)
out._matrix = self._matrix.clone()
for other in others:
if not isinstance(other, Transform3d):
msg = "Only possible to compose Transform3d objects; got %s"
raise ValueError(msg % type(other))
out._transforms = self._transforms + list(others)
return out
def get_matrix(self):
"""
Return a matrix which is the result of composing this transform
with others stored in self.transforms. Where necessary transforms
are broadcast against each other.
For example, if self.transforms contains transforms t1, t2, and t3, and
given a set of points x, the following should be true:
.. code-block:: python
y1 = t1.compose(t2, t3).transform_points(x)
y2 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
# y1 and y2 agree up to numerical precision
Returns:
A transformation matrix representing the composed inputs.
"""
composed_matrix = self._matrix.clone()
if len(self._transforms) > 0:
for other in self._transforms:
other_matrix = other.get_matrix()
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
return composed_matrix
def _get_matrix_inverse(self):
"""
Return the inverse of self._matrix.
"""
return torch.inverse(self._matrix)
def inverse(self, invert_composed: bool = False):
"""
Returns a new Transform3D object that represents an inverse of the
current transformation.
Args:
invert_composed:
- True: First compose the list of stored transformations
and then apply inverse to the result. This is
potentially slower for classes of transformations
with inverses that can be computed efficiently
(e.g. rotations and translations).
- False: Invert the individual stored transformations
independently without composing them.
Returns:
A new Transform3D object containing the inverse of the original
transformation.
"""
tinv = Transform3d(device=self.device)
if invert_composed:
# first compose then invert
tinv._matrix = torch.inverse(self.get_matrix())
else:
# self._get_matrix_inverse() implements efficient inverse
# of self._matrix
i_matrix = self._get_matrix_inverse()
# 2 cases:
if len(self._transforms) > 0:
# a) Either we have a non-empty list of transforms:
# Here we take self._matrix and append its inverse at the
# end of the reverted _transforms list. After composing
# the transformations with get_matrix(), this correctly
# right-multiplies by the inverse of self._matrix
# at the end of the composition.
tinv._transforms = [
t.inverse() for t in reversed(self._transforms)
]
last = Transform3d(device=self.device)
last._matrix = i_matrix
tinv._transforms.append(last)
else:
# b) Or there are no stored transformations
# we just set inverted matrix
tinv._matrix = i_matrix
return tinv
def stack(self, *others):
transforms = [self] + list(others)
matrix = torch.cat([t._matrix for t in transforms], dim=0)
out = Transform3d()
out._matrix = matrix
return out
def transform_points(self, points, eps: float = None):
"""
Use this transform to transform a set of 3D points. Assumes row major
ordering of the input points.
Args:
points: Tensor of shape (P, 3) or (N, P, 3)
eps: If eps!=None, the argument is used to clamp the
last coordinate before performing the final division.
The clamping corresponds to:
last_coord := (last_coord.sign() + (last_coord==0)) *
torch.clamp(last_coord.abs(), eps),
i.e. the last coordinates that are exactly 0 will
be clamped to +eps.
Returns:
points_out: points of shape (N, P, 3) or (P, 3) depending
on the dimensions of the transform
"""
points_batch = points.clone()
if points_batch.dim() == 2:
points_batch = points_batch[None] # (P, 3) -> (1, P, 3)
if points_batch.dim() != 3:
msg = "Expected points to have dim = 2 or dim = 3: got shape %r"
raise ValueError(msg % points.shape)
N, P, _3 = points_batch.shape
ones = torch.ones(N, P, 1, dtype=points.dtype, device=points.device)
points_batch = torch.cat([points_batch, ones], dim=2)
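# Homogeneous coordinates with a row-vector convention: transformed = [x, y, z, 1] @ M.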
composed_matrix = self.get_matrix()
points_out = _broadcast_bmm(points_batch, composed_matrix)
denom = points_out[..., 3:] # denominator
if eps is not None:
denom_sign = denom.sign() + (denom == 0.0).type_as(denom)
denom = denom_sign * torch.clamp(denom.abs(), eps)
points_out = points_out[..., :3] / denom
# When transform is (1, 4, 4) and points is (P, 3) return
# points_out of shape (P, 3)
if points_out.shape[0] == 1 and points.dim() == 2:
points_out = points_out.reshape(points.shape)
return points_out
def transform_normals(self, normals):
"""
Use this transform to transform a set of normal vectors.
Args:
normals: Tensor of shape (P, 3) or (N, P, 3)
Returns:
normals_out: Tensor of shape (P, 3) or (N, P, 3) depending
on the dimensions of the transform
"""
if normals.dim() not in [2, 3]:
msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
raise ValueError(msg % normals.shape)
composed_matrix = self.get_matrix()
# TODO: inverse is bad! Solve a linear system instead
mat = composed_matrix[:, :3, :3]
normals_out = _broadcast_bmm(normals, mat.transpose(1, 2).inverse())
# This doesn't pass unit tests. TODO investigate further
# if self._lu is None:
# self._lu = self._matrix[:, :3, :3].transpose(1, 2).lu()
# normals_out = normals.lu_solve(*self._lu)
# When transform is (1, 4, 4) and normals is (P, 3) return
# normals_out of shape (P, 3)
if normals_out.shape[0] == 1 and normals.dim() == 2:
normals_out = normals_out.reshape(normals.shape)
return normals_out
def translate(self, *args, **kwargs):
return self.compose(Translate(device=self.device, *args, **kwargs))
def scale(self, *args, **kwargs):
return self.compose(Scale(device=self.device, *args, **kwargs))
def rotate_axis_angle(self, *args, **kwargs):
return self.compose(
RotateAxisAngle(device=self.device, *args, **kwargs)
)
def clone(self):
"""
Deep copy of Transforms object. All internal tensors are cloned
individually.
Returns:
new Transforms object.
"""
other = Transform3d(device=self.device)
if self._lu is not None:
other._lu = [l.clone() for l in self._lu]
other._matrix = self._matrix.clone()
other._transforms = [t.clone() for t in self._transforms]
return other
def to(self, device, copy: bool = False, dtype=None):
"""
Match functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
returned tensor is a copy of self with the desired torch.device.
If copy = False and the self Tensor already has the correct torch.device,
then self is returned.
Args:
device: Device id for the new tensor.
copy: Boolean indicator whether or not to clone self. Default False.
dtype: If not None, casts the internal tensor variables
to a given torch.dtype.
Returns:
Transform3d object.
"""
if not copy and self.device == device:
return self
other = self.clone()
if self.device != device:
other.device = device
other._matrix = self._matrix.to(device=device, dtype=dtype)
for t in other._transforms:
t.to(device, copy=copy, dtype=dtype)
return other
def cpu(self):
return self.to(torch.device("cpu"))
def cuda(self):
return self.to(torch.device("cuda"))
class Translate(Transform3d):
def __init__(
self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
):
"""
Create a new Transform3d representing 3D translations.
Option I: Translate(xyz, dtype=torch.float32, device='cpu')
xyz should be a tensor of shape (N, 3)
Option II: Translate(x, y, z, dtype=torch.float32, device='cpu')
Here x, y, and z will be broadcast against each other and
concatenated to form the translation. Each can be:
- A python scalar
- A torch scalar
- A 1D torch tensor
"""
super().__init__(device=device)
xyz = _handle_input(x, y, z, dtype, device, "Translate")
N = xyz.shape[0]
mat = torch.eye(4, dtype=dtype, device=device)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, 3, :3] = xyz
self._matrix = mat
def _get_matrix_inverse(self):
"""
Return the inverse of self._matrix.
"""
inv_mask = self._matrix.new_ones([1, 4, 4])
inv_mask[0, 3, :3] = -1.0
i_matrix = self._matrix * inv_mask
return i_matrix
class Scale(Transform3d):
def __init__(
self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
):
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
Option I: Scale(s, dtype=torch.float32, device='cpu')
s can be one of
- Python scalar or torch scalar: Single uniform scale
- 1D torch tensor of shape (N,): A batch of uniform scale
- 2D torch tensor of shape (N, 3): Scale differently along each axis
Option II: Scale(x, y, z, dtype=torch.float32, device='cpu')
Each of x, y, and z can be one of
- python scalar
- torch scalar
- 1D torch tensor
"""
super().__init__(device=device)
xyz = _handle_input(
x, y, z, dtype, device, "scale", allow_singleton=True
)
N = xyz.shape[0]
# TODO: Can we do this all in one go somehow?
mat = torch.eye(4, dtype=dtype, device=device)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, 0, 0] = xyz[:, 0]
mat[:, 1, 1] = xyz[:, 1]
mat[:, 2, 2] = xyz[:, 2]
self._matrix = mat
def _get_matrix_inverse(self):
"""
Return the inverse of self._matrix.
"""
xyz = torch.stack([self._matrix[:, i, i] for i in range(4)], dim=1)
ixyz = 1.0 / xyz
imat = torch.diag_embed(ixyz, dim1=1, dim2=2)
return imat
class Rotate(Transform3d):
def __init__(
self,
R,
dtype=torch.float32,
device: str = "cpu",
orthogonal_tol: float = 1e-5,
):
"""
Create a new Transform3d representing 3D rotation using a rotation
matrix as the input.
Args:
R: a tensor of shape (3, 3) or (N, 3, 3)
orthogonal_tol: tolerance for the test of the orthogonality of R
"""
super().__init__(device=device)
if R.dim() == 2:
R = R[None]
if R.shape[-2:] != (3, 3):
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
raise ValueError(msg % repr(R.shape))
R = R.to(dtype=dtype).to(device=device)
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
N = R.shape[0]
mat = torch.eye(4, dtype=dtype, device=device)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, :3, :3] = R
self._matrix = mat
def _get_matrix_inverse(self):
"""
Return the inverse of self._matrix.
"""
return self._matrix.permute(0, 2, 1).contiguous()
class RotateAxisAngle(Rotate):
def __init__(
self,
angle,
axis: str = "X",
degrees: bool = True,
dtype=torch.float64,
device: str = "cpu",
):
"""
Create a new Transform3d representing 3D rotation about an axis
by an angle.
Args:
angle:
- A torch tensor of shape (N, 1)
- A python scalar
- A torch scalar
axis:
string: one of ["X", "Y", "Z"] indicating the axis about which
to rotate.
NOTE: All batch elements are rotated about the same axis.
"""
axis = axis.upper()
if axis not in ["X", "Y", "Z"]:
msg = "Expected axis to be one of ['X', 'Y', 'Z']; got %s"
raise ValueError(msg % axis)
angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
angle = (angle / 180.0 * math.pi) if degrees else angle
N = angle.shape[0]
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
R = torch.stack(R_flat, -1).reshape((N, 3, 3))
super().__init__(device=device, R=R)
def _handle_coord(c, dtype, device):
"""
Helper function for _handle_input.
Args:
c: Python scalar, torch scalar, or 1D torch tensor
Returns:
c_vec: 1D torch tensor
"""
if not torch.is_tensor(c):
c = torch.tensor(c, dtype=dtype, device=device)
if c.dim() == 0:
c = c.view(1)
return c
def _handle_input(
x, y, z, dtype, device, name: str, allow_singleton: bool = False
):
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed
input.
Case I: Single Matrix
In this case x is a tensor of shape (N, 3), and y and z are None. Here just
return x.
Case II: Vectors and Scalars
In this case each of x, y, and z can be one of the following
- Python scalar
- Torch scalar
- Torch tensor of shape (N, 1) or (1, 1)
In this case x, y and z are broadcast to tensors of shape (N, 1)
and concatenated to a tensor of shape (N, 3)
Case III: Singleton (only if allow_singleton=True)
In this case y and z are None, and x can be one of the following:
- Python scalar
- Torch scalar
- Torch tensor of shape (N, 1) or (1, 1)
Here x will be duplicated 3 times, and we return a tensor of shape (N, 3)
Returns:
xyz: Tensor of shape (N, 3)
"""
# If x is actually a tensor of shape (N, 3) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 3:
msg = "Expected tensor of shape (N, 3); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
if y is not None or z is not None:
msg = "Expected y and z to be None (in %s)" % name
raise ValueError(msg)
return x
if allow_singleton and y is None and z is None:
y = x
z = x
# Convert all to 1D tensors
xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
# Broadcast and concatenate
sizes = [c.shape[0] for c in xyz]
N = max(sizes)
for c in xyz:
if c.shape[0] != 1 and c.shape[0] != N:
msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name)
raise ValueError(msg)
xyz = [c.expand(N) for c in xyz]
xyz = torch.stack(xyz, dim=1)
return xyz
def _handle_angle_input(x, dtype, device: str, name: str):
"""
Helper function for building a rotation function using angles.
The output is always of shape (N, 1).
The input can be one of:
- Torch tensor (N, 1) or (N)
- Python scalar
- Torch scalar
"""
# If x is actually a tensor of shape (N, 1) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 1:
msg = "Expected tensor of shape (N, 1); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
return x
else:
return _handle_coord(x, dtype, device)
def _broadcast_bmm(a, b):
"""
Batch multiply two matrices and broadcast if necessary.
Args:
a: torch tensor of shape (P, K) or (M, P, K)
b: torch tensor of shape (N, K, K)
Returns:
a and b broadcast multiplied. The output batch dimension is max(N, M).
To broadcast transforms across a batch dimension, if M != N then either
M = 1 or N = 1 is expected. The tensor with batch dimension 1 is
expanded to have shape N or M.
"""
if a.dim() == 2:
a = a[None]
if len(a) != len(b):
if not ((len(a) == 1) or (len(b) == 1)):
msg = "Expected batch dim for bmm to be equal or 1; got %r, %r"
raise ValueError(msg % (a.shape, b.shape))
if len(a) == 1:
a = a.expand(len(b), -1, -1)
if len(b) == 1:
b = b.expand(len(a), -1, -1)
return a.bmm(b)
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
"""
Determine if R is a valid rotation matrix by checking it satisfies the
following conditions:
``RR^T = I and det(R) = 1``
Args:
R: an (N, 3, 3) matrix
Returns:
None
Prints a warning if R is an invalid rotation matrix; otherwise returns silently.
"""
N = R.shape[0]
eye = torch.eye(3, dtype=R.dtype, device=R.device)
eye = eye.view(1, 3, 3).expand(N, -1, -1)
orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
det_R = torch.det(R)
no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
if not (orthogonal and no_distortion):
msg = "R is not a valid rotation matrix"
print(msg)
return
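
A minimal sketch of building and applying a transform with the classes above (assuming they are in scope):

import torch
t = Transform3d().scale(2.0).translate(1.0, 2.0, 3.0)   # scale first, then translate
x = torch.randn(2, 5, 3)
y = t.transform_points(x)                                # row-vector convention: y = 2 * x + (1, 2, 3)
assert torch.allclose(y, x * 2.0 + torch.tensor([1.0, 2.0, 3.0]), atol=1e-5)
# The inverse transform recovers the original points.
assert torch.allclose(t.inverse().transform_points(y), x, atol=1e-5)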