transforms 3d convention fix

Summary: Fixed the rotation matrices generated by the RotateAxisAngle class and updated the tests. Added documentation for Transforms3d to clarify the conventions.

Reviewed By: gkioxari

Differential Revision: D19912903

fbshipit-source-id: c64926ce4e1381b145811557c32b73663d6d92d1
This commit is contained in:
Nikhila Ravi 2020-02-19 10:31:10 -08:00 committed by Facebook Github Bot
parent bdc2bb578c
commit 8301163d24
4 changed files with 203 additions and 104 deletions

View File

@ -5,6 +5,32 @@ import functools
import torch import torch
"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as
R = [
[Rxx, Rxy, Rxz],
[Ryx, Ryy, Ryz],
[Rzx, Rzy, Rzz],
] # (3, 3)
This matrix can be applied to column vectors by post multiplication
by the points e.g.
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
transformed_points = R * points
To apply the same matrix to points which are row vectors, the R matrix
can be transposed and pre multiplied by the points:
e.g.
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * R.transpose(1, 0)
"""
def quaternion_to_matrix(quaternions): def quaternion_to_matrix(quaternions):
""" """
Convert rotations given as quaternions to rotation matrices. Convert rotations given as quaternions to rotation matrices.
@ -80,7 +106,7 @@ def matrix_to_quaternion(matrix):
return torch.stack((o0, o1, o2, o3), -1) return torch.stack((o0, o1, o2, o3), -1)
def _primary_matrix(axis: str, angle): def _axis_angle_rotation(axis: str, angle):
""" """
Return the rotation matrices for one of the rotations about an axis Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given. of which Euler angles describe, for each value of the angle given.
@ -92,17 +118,20 @@ def _primary_matrix(axis: str, angle):
Returns: Returns:
Rotation matrices as tensor of shape (..., 3, 3). Rotation matrices as tensor of shape (..., 3, 3).
""" """
cos = torch.cos(angle) cos = torch.cos(angle)
sin = torch.sin(angle) sin = torch.sin(angle)
one = torch.ones_like(angle) one = torch.ones_like(angle)
zero = torch.zeros_like(angle) zero = torch.zeros_like(angle)
if axis == "X": if axis == "X":
o = (one, zero, zero, zero, cos, -sin, zero, sin, cos) R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y": if axis == "Y":
o = (cos, zero, sin, zero, one, zero, -sin, zero, cos) R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z": if axis == "Z":
o = (cos, -sin, zero, sin, cos, zero, zero, zero, one) R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
return torch.stack(o, -1).reshape(angle.shape + (3, 3))
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str): def euler_angles_to_matrix(euler_angles, convention: str):
@ -126,7 +155,9 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention: for letter in convention:
if letter not in ("X", "Y", "Z"): if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.") raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(_primary_matrix, convention, torch.unbind(euler_angles, -1)) matrices = map(
_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
)
return functools.reduce(torch.matmul, matrices) return functools.reduce(torch.matmul, matrices)

View File

@ -5,6 +5,8 @@ import math
import warnings import warnings
import torch import torch
from .rotation_conversions import _axis_angle_rotation
class Transform3d: class Transform3d:
""" """
@ -103,12 +105,35 @@ class Transform3d:
s1_params -= lr * s1_params.grad s1_params -= lr * s1_params.grad
t_params -= lr * t_params.grad t_params -= lr * t_params.grad
s2_params -= lr * s2_params.grad s2_params -= lr * s2_params.grad
CONVENTIONS
We adopt a right-hand coordinate system, meaning that rotation about an axis
with a positive angle results in a counter clockwise rotation.
This class assumes that transformations are applied on inputs which
are row vectors. The internal representation of the Nx4x4 transformation
matrix is of the form:
.. code-block:: python
M = [
[Rxx, Ryx, Rzx, 0],
[Rxy, Ryy, Rzy, 0],
[Rxz, Ryz, Rzz, 0],
[Tx, Ty, Tz, 1],
]
To apply the transformation to points which are row vectors, the M matrix
can be pre multiplied by the points:
.. code-block:: python
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * M
""" """
def __init__(self, dtype=torch.float32, device="cpu"): def __init__(self, dtype=torch.float32, device="cpu"):
"""
This class assumes a row major ordering for all matrices.
"""
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4) self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
self._transforms = [] # store transforms to compose self._transforms = [] # store transforms to compose
self._lu = None self._lu = None
@ -493,9 +518,12 @@ class RotateAxisAngle(Rotate):
Create a new Transform3d representing 3D rotation about an axis Create a new Transform3d representing 3D rotation about an axis
by an angle. by an angle.
Assuming a right-hand coordinate system, positive rotation angles result
in a counter clockwise rotation.
Args: Args:
angle: angle:
- A torch tensor of shape (N, 1) - A torch tensor of shape (N,)
- A python scalar - A python scalar
- A torch scalar - A torch scalar
axis: axis:
@ -509,21 +537,11 @@ class RotateAxisAngle(Rotate):
raise ValueError(msg % axis) raise ValueError(msg % axis)
angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle") angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
angle = (angle / 180.0 * math.pi) if degrees else angle angle = (angle / 180.0 * math.pi) if degrees else angle
N = angle.shape[0] # We assume the points on which this transformation will be applied
# are row vectors. The rotation matrix returned from _axis_angle_rotation
cos = torch.cos(angle) # is for transforming column vectors. Therefore we transpose this matrix.
sin = torch.sin(angle) # R will always be of shape (N, 3, 3)
one = torch.ones_like(angle) R = _axis_angle_rotation(axis, angle).transpose(1, 2)
zero = torch.zeros_like(angle)
if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
R = torch.stack(R_flat, -1).reshape((N, 3, 3))
super().__init__(device=device, R=R) super().__init__(device=device, R=R)
@ -606,19 +624,16 @@ def _handle_input(
def _handle_angle_input(x, dtype, device: str, name: str): def _handle_angle_input(x, dtype, device: str, name: str):
""" """
Helper function for building a rotation function using angles. Helper function for building a rotation function using angles.
The output is always of shape (N, 1). The output is always of shape (N,).
The input can be one of: The input can be one of:
- Torch tensor (N, 1) or (N) - Torch tensor of shape (N,)
- Python scalar - Python scalar
- Torch scalar - Torch scalar
""" """
# If x is actually a tensor of shape (N, 1) then just return it if torch.is_tensor(x) and x.dim() > 1:
if torch.is_tensor(x) and x.dim() == 2: msg = "Expected tensor of shape (N,); got %r (in %s)"
if x.shape[1] != 1:
msg = "Expected tensor of shape (N, 1); got %r (in %s)"
raise ValueError(msg % (x.shape, name)) raise ValueError(msg % (x.shape, name))
return x
else: else:
return _handle_coord(x, dtype, device) return _handle_coord(x, dtype, device)

View File

@ -8,6 +8,7 @@ import unittest
import torch import torch
from pytorch3d.transforms.rotation_conversions import ( from pytorch3d.transforms.rotation_conversions import (
_axis_angle_rotation,
euler_angles_to_matrix, euler_angles_to_matrix,
matrix_to_euler_angles, matrix_to_euler_angles,
matrix_to_quaternion, matrix_to_quaternion,
@ -118,7 +119,6 @@ class TestRotationConversion(unittest.TestCase):
def test_to_euler(self): def test_to_euler(self):
"""mtx -> euler -> mtx""" """mtx -> euler -> mtx"""
data = random_rotations(13, dtype=torch.float64) data = random_rotations(13, dtype=torch.float64)
for convention in self._all_euler_angle_conventions(): for convention in self._all_euler_angle_conventions():
euler_angles = matrix_to_euler_angles(data, convention) euler_angles = matrix_to_euler_angles(data, convention)
mdata = euler_angles_to_matrix(euler_angles, convention) mdata = euler_angles_to_matrix(euler_angles, convention)

View File

@ -120,7 +120,7 @@ class TestTransform(unittest.TestCase):
self.assertTrue(torch.allclose(normals_out, normals_out_expected)) self.assertTrue(torch.allclose(normals_out, normals_out_expected))
def test_rotate_axis_angle(self): def test_rotate_axis_angle(self):
t = Transform3d().rotate_axis_angle(-90.0, axis="Z") t = Transform3d().rotate_axis_angle(90.0, axis="Z")
points = torch.tensor( points = torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]] [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]
).view(1, 3, 3) ).view(1, 3, 3)
@ -738,14 +738,22 @@ class TestRotateAxisAngle(unittest.TestCase):
[ [
[ [
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
) )
# fmt: on # fmt: on
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, 1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix)) self.assertTrue(torch.allclose(t._matrix, matrix))
def test_rotate_x_torch_scalar(self): def test_rotate_x_torch_scalar(self):
@ -756,14 +764,22 @@ class TestRotateAxisAngle(unittest.TestCase):
[ [
[ [
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
) )
# fmt: on # fmt: on
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, 1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_x_torch_tensor(self): def test_rotate_x_torch_tensor(self):
@ -782,14 +798,14 @@ class TestRotateAxisAngle(unittest.TestCase):
], ],
[ [
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, r2_2, -r2_i, 0.0], # noqa: E241, E201 [0.0, r2_2, r2_i, 0.0], # noqa: E241, E201
[0.0, r2_i, r2_2, 0.0], # noqa: E241, E201 [0.0, -r2_i, r2_2, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
], ],
[ [
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
@ -797,7 +813,7 @@ class TestRotateAxisAngle(unittest.TestCase):
) )
# fmt: on # fmt: on
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
angle = angle[..., None] # (N, 1) angle = angle
t = RotateAxisAngle(angle=angle, axis="X") t = RotateAxisAngle(angle=angle, axis="X")
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
@ -807,33 +823,54 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor( matrix = torch.tensor(
[ [
[ [
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
) )
# fmt: on # fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, -1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_y_torch_scalar(self): def test_rotate_y_torch_scalar(self):
"""
Test rotation about Y axis. With a right hand coordinate system this
should result in a vector pointing along the x-axis being rotated to
point along the negative z axis.
"""
angle = torch.tensor(90.0) angle = torch.tensor(90.0)
t = RotateAxisAngle(angle=angle, axis="Y") t = RotateAxisAngle(angle=angle, axis="Y")
# fmt: off # fmt: off
matrix = torch.tensor( matrix = torch.tensor(
[ [
[ [
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
) )
# fmt: on # fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 0.0, -1.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_y_torch_tensor(self): def test_rotate_y_torch_tensor(self):
@ -851,16 +888,16 @@ class TestRotateAxisAngle(unittest.TestCase):
[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0],
], ],
[ [
[ r2_2, 0.0, r2_i, 0.0], # noqa: E241, E201 [r2_2, 0.0, -r2_i, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-r2_i, 0.0, r2_2, 0.0], # noqa: E241, E201 [r2_i, 0.0, r2_2, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
], ],
[ [
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
@ -874,15 +911,23 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor( matrix = torch.tensor(
[ [
[ [
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
) )
# fmt: on # fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 1.0, 0.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_z_torch_scalar(self): def test_rotate_z_torch_scalar(self):
@ -892,15 +937,23 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor( matrix = torch.tensor(
[ [
[ [
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
) )
# fmt: on # fmt: on
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
transformed_points = t.transform_points(points)
expected_points = torch.tensor([0.0, 1.0, 0.0])
self.assertTrue(
torch.allclose(
transformed_points.squeeze(), expected_points, atol=1e-7
)
)
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7)) self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
def test_rotate_z_torch_tensor(self): def test_rotate_z_torch_tensor(self):
@ -918,16 +971,16 @@ class TestRotateAxisAngle(unittest.TestCase):
[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0],
], ],
[ [
[r2_2, -r2_i, 0.0, 0.0], # noqa: E241, E201 [ r2_2, r2_i, 0.0, 0.0], # noqa: E241, E201
[r2_i, r2_2, 0.0, 0.0], # noqa: E241, E201 [-r2_i, r2_2, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
], ],
[ [
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
@ -946,8 +999,8 @@ class TestRotateAxisAngle(unittest.TestCase):
[ [
[ [
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, -1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, -1.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
@ -956,10 +1009,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix2 = torch.tensor( matrix2 = torch.tensor(
[ [
[ [
[ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [0.0, 0.0, -1.0, 0.0], # noqa: E241, E201
[ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201 [0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
@ -967,10 +1020,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix3 = torch.tensor( matrix3 = torch.tensor(
[ [
[ [
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
@ -987,10 +1040,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor( matrix = torch.tensor(
[ [
[ [
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
@ -1004,10 +1057,10 @@ class TestRotateAxisAngle(unittest.TestCase):
matrix = torch.tensor( matrix = torch.tensor(
[ [
[ [
[0.0, -1.0, 0.0, 0.0], # noqa: E241, E201 [ 0.0, 1.0, 0.0, 0.0], # noqa: E241, E201
[1.0, 0.0, 0.0, 0.0], # noqa: E241, E201 [-1.0, 0.0, 0.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 1.0, 0.0], # noqa: E241, E201 [ 0.0, 0.0, 1.0, 0.0], # noqa: E241, E201
[0.0, 0.0, 0.0, 1.0], # noqa: E241, E201 [ 0.0, 0.0, 0.0, 1.0], # noqa: E241, E201
] ]
], ],
dtype=torch.float32, dtype=torch.float32,