transforms 3d convention fix

Summary: Fixed the rotation matrices generated by the RotateAxisAngle class and updated the tests. Added documentation for Transforms3d to clarify the conventions.

Reviewed By: gkioxari

Differential Revision: D19912903

fbshipit-source-id: c64926ce4e1381b145811557c32b73663d6d92d1
This commit is contained in:
Nikhila Ravi 2020-02-19 10:31:10 -08:00 committed by Facebook Github Bot
parent bdc2bb578c
commit 8301163d24
4 changed files with 203 additions and 104 deletions

View File

@ -5,6 +5,32 @@ import functools
import torch
"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as
R = [
[Rxx, Rxy, Rxz],
[Ryx, Ryy, Ryz],
[Rzx, Rzy, Rzz],
] # (3, 3)
This matrix can be applied to column vectors by post multiplication
by the points e.g.
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
transformed_points = R * points
To apply the same matrix to points which are row vectors, the R matrix
can be transposed and pre multiplied by the points:
e.g.
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * R.transpose(1, 0)
"""
def quaternion_to_matrix(quaternions):
"""
Convert rotations given as quaternions to rotation matrices.
@ -80,7 +106,7 @@ def matrix_to_quaternion(matrix):
return torch.stack((o0, o1, o2, o3), -1)
def _primary_matrix(axis: str, angle):
def _axis_angle_rotation(axis: str, angle):
"""
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
@ -92,17 +118,20 @@ def _primary_matrix(axis: str, angle):
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
o = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
o = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
o = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
return torch.stack(o, -1).reshape(angle.shape + (3, 3))
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str):
@ -126,7 +155,9 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(_primary_matrix, convention, torch.unbind(euler_angles, -1))
matrices = map(
_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
)
return functools.reduce(torch.matmul, matrices)

View File

@ -5,6 +5,8 @@ import math
import warnings
import torch
from .rotation_conversions import _axis_angle_rotation
class Transform3d:
"""
@ -103,12 +105,35 @@ class Transform3d:
s1_params -= lr * s1_params.grad
t_params -= lr * t_params.grad
s2_params -= lr * s2_params.grad
CONVENTIONS
We adopt a right-hand coordinate system, meaning that rotation about an axis
with a positive angle results in a counter clockwise rotation.
This class assumes that transformations are applied on inputs which
are row vectors. The internal representation of the Nx4x4 transformation
matrix is of the form:
.. code-block:: python
M = [
[Rxx, Ryx, Rzx, 0],
[Rxy, Ryy, Rzy, 0],
[Rxz, Ryz, Rzz, 0],
[Tx, Ty, Tz, 1],
]
To apply the transformation to points which are row vectors, the M matrix
can be pre multiplied by the points:
.. code-block:: python
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * M
"""
def __init__(self, dtype=torch.float32, device="cpu"):
"""
This class assumes a row major ordering for all matrices.
"""
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
self._transforms = [] # store transforms to compose
self._lu = None
@ -493,9 +518,12 @@ class RotateAxisAngle(Rotate):
Create a new Transform3d representing 3D rotation about an axis
by an angle.
Assuming a right-hand coordinate system, positive rotation angles result
in a counter clockwise rotation.
Args:
angle:
- A torch tensor of shape (N, 1)
- A torch tensor of shape (N,)
- A python scalar
- A torch scalar
axis:
@ -509,21 +537,11 @@ class RotateAxisAngle(Rotate):
raise ValueError(msg % axis)
angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
angle = (angle / 180.0 * math.pi) if degrees else angle
N = angle.shape[0]
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
R = torch.stack(R_flat, -1).reshape((N, 3, 3))
# We assume the points on which this transformation will be applied
# are row vectors. The rotation matrix returned from _axis_angle_rotation
# is for transforming column vectors. Therefore we transpose this matrix.
# R will always be of shape (N, 3, 3)
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
super().__init__(device=device, R=R)
@ -606,19 +624,16 @@ def _handle_input(
def _handle_angle_input(x, dtype, device: str, name: str):
"""
Helper function for building a rotation function using angles.
The output is always of shape (N, 1).
The output is always of shape (N,).
The input can be one of:
- Torch tensor (N, 1) or (N)
- Torch tensor of shape (N,)
- Python scalar
- Torch scalar
"""
# If x is actually a tensor of shape (N, 1) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 1:
msg = "Expected tensor of shape (N, 1); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
return x
if torch.is_tensor(x) and x.dim() > 1:
msg = "Expected tensor of shape (N,); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
else:
return _handle_coord(x, dtype, device)

View File

@ -8,6 +8,7 @@ import unittest
import torch
from pytorch3d.transforms.rotation_conversions import (
_axis_angle_rotation,
euler_angles_to_matrix,
matrix_to_euler_angles,
matrix_to_quaternion,
@ -118,7 +119,6 @@ class TestRotationConversion(unittest.TestCase):
def test_to_euler(self):
"""mtx -> euler -> mtx"""
data = random_rotations(13, dtype=torch.float64)
for convention in self._all_euler_angle_conventions():
euler_angles = matrix_to_euler_angles(data, convention)
mdata = euler_angles_to_matrix(euler_angles, convention)

View File

@ -120,7 +120,7 @@ class TestTransform(unittest.TestCase):
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
def test_rotate_axis_angle(self):
t = Transform3d().rotate_axis_angle(-90.0, axis="Z")
t = Transform3d().rotate_axis_angle(90.0, axis="Z")
points = torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]
).view(1, 3, 3)
@ -737,15 +737,23 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor(
[
[
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
)
# fmt: on
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, 1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix))
def test_rotate_x_torch_scalar(self):
@ -755,15 +763,23 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor(
[
[
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
)
# fmt: on
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, 1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_x_torch_tensor(self):
@ -781,23 +797,23 @@ class TestRotateAxisAngle(unittest.TestCase):
[0.0, 0.0, 0.0, 1.0],
],
[
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, r2_2, -r2_i, 0.0], # noqa: E241, E201
[0.0, r2_i, r2_2, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, r2_2, r2_i, 0.0], # noqa: E241, E201
[0.0, -r2_i, r2_2, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
],
[
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
)
# fmt: on
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
angle = angle[..., None] # (N, 1)
angle = angle
t = RotateAxisAngle(angle=angle, axis="X")
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
@ -807,33 +823,54 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor(
[
[
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
)
# fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, -1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_y_torch_scalar(self):
"""
Test rotation about Y axis. With a right hand coordinate system this
should result in a vector pointing along the x-axis being rotated to
point along the negative z axis.
"""
angle = torch.tensor(90.0)
t = RotateAxisAngle(angle=angle, axis="Y")
# fmt: off
matrix = torch.tensor(
[
[
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
)
# fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, -1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_y_torch_tensor(self):
@ -851,16 +888,16 @@ class TestRotateAxisAngle(unittest.TestCase):
[0.0, 0.0, 0.0, 1.0],
],
[
[ r2_2, 0.0, r2_i, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-r2_i, 0.0, r2_2, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[r2_2, 0.0, -r2_i, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[r2_i, 0.0, r2_2, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
],
[
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
@ -874,15 +911,23 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor(
[
[
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
)
# fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 1.0, 0.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_z_torch_scalar(self):
@ -892,15 +937,23 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor(
[
[
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
)
# fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 1.0, 0.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_z_torch_tensor(self):
@ -918,16 +971,16 @@ class TestRotateAxisAngle(unittest.TestCase):
[0.0, 0.0, 0.0, 1.0],
],
[
[r2_2, -r2_i, 0.0, 0.0], # noqa: E241, E201
[r2_i, r2_2, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[ r2_2, r2_i, 0.0, 0.0], # noqa: E241, E201
[-r2_i, r2_2, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
],
[
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
@ -945,10 +998,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix1 = torch.tensor(
[
[
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
@ -956,10 +1009,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix2 = torch.tensor(
[
[
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
@ -967,10 +1020,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix3 = torch.tensor(
[
[
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
@ -987,10 +1040,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor(
[
[
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,
@ -1004,10 +1057,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor(
[
[
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
]
],
dtype=torch.float32,