mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	transforms 3d convention fix
Summary: Fixed the rotation matrices generated by the RotateAxisAngle class and updated the tests. Added documentation for Transforms3d to clarify the conventions. Reviewed By: gkioxari Differential Revision: D19912903 fbshipit-source-id: c64926ce4e1381b145811557c32b73663d6d92d1
This commit is contained in:
		
							parent
							
								
									bdc2bb578c
								
							
						
					
					
						commit
						8301163d24
					
				@ -5,6 +5,32 @@ import functools
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
The transformation matrices returned from the functions in this file assume
 | 
			
		||||
the points on which the transformation will be applied are column vectors.
 | 
			
		||||
i.e. the R matrix is structured as
 | 
			
		||||
 | 
			
		||||
    R = [
 | 
			
		||||
            [Rxx, Rxy, Rxz],
 | 
			
		||||
            [Ryx, Ryy, Ryz],
 | 
			
		||||
            [Rzx, Rzy, Rzz],
 | 
			
		||||
        ]  # (3, 3)
 | 
			
		||||
 | 
			
		||||
This matrix can be applied to column vectors by post multiplication
 | 
			
		||||
by the points e.g.
 | 
			
		||||
 | 
			
		||||
    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
 | 
			
		||||
    transformed_points = R * points
 | 
			
		||||
 | 
			
		||||
To apply the same matrix to points which are row vectors, the R matrix
 | 
			
		||||
can be transposed and pre multiplied by the points:
 | 
			
		||||
 | 
			
		||||
e.g.
 | 
			
		||||
    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
 | 
			
		||||
    transformed_points = points * R.transpose(1, 0)
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def quaternion_to_matrix(quaternions):
 | 
			
		||||
    """
 | 
			
		||||
    Convert rotations given as quaternions to rotation matrices.
 | 
			
		||||
@ -80,7 +106,7 @@ def matrix_to_quaternion(matrix):
 | 
			
		||||
    return torch.stack((o0, o1, o2, o3), -1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _primary_matrix(axis: str, angle):
 | 
			
		||||
def _axis_angle_rotation(axis: str, angle):
 | 
			
		||||
    """
 | 
			
		||||
    Return the rotation matrices for one of the rotations about an axis
 | 
			
		||||
    of which Euler angles describe, for each value of the angle given.
 | 
			
		||||
@ -92,17 +118,20 @@ def _primary_matrix(axis: str, angle):
 | 
			
		||||
    Returns:
 | 
			
		||||
        Rotation matrices as tensor of shape (..., 3, 3).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    cos = torch.cos(angle)
 | 
			
		||||
    sin = torch.sin(angle)
 | 
			
		||||
    one = torch.ones_like(angle)
 | 
			
		||||
    zero = torch.zeros_like(angle)
 | 
			
		||||
 | 
			
		||||
    if axis == "X":
 | 
			
		||||
        o = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
 | 
			
		||||
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
 | 
			
		||||
    if axis == "Y":
 | 
			
		||||
        o = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
 | 
			
		||||
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
 | 
			
		||||
    if axis == "Z":
 | 
			
		||||
        o = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
 | 
			
		||||
    return torch.stack(o, -1).reshape(angle.shape + (3, 3))
 | 
			
		||||
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
 | 
			
		||||
 | 
			
		||||
    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def euler_angles_to_matrix(euler_angles, convention: str):
 | 
			
		||||
@ -126,7 +155,9 @@ def euler_angles_to_matrix(euler_angles, convention: str):
 | 
			
		||||
    for letter in convention:
 | 
			
		||||
        if letter not in ("X", "Y", "Z"):
 | 
			
		||||
            raise ValueError(f"Invalid letter {letter} in convention string.")
 | 
			
		||||
    matrices = map(_primary_matrix, convention, torch.unbind(euler_angles, -1))
 | 
			
		||||
    matrices = map(
 | 
			
		||||
        _axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
 | 
			
		||||
    )
 | 
			
		||||
    return functools.reduce(torch.matmul, matrices)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -5,6 +5,8 @@ import math
 | 
			
		||||
import warnings
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from .rotation_conversions import _axis_angle_rotation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Transform3d:
 | 
			
		||||
    """
 | 
			
		||||
@ -103,12 +105,35 @@ class Transform3d:
 | 
			
		||||
            s1_params -= lr * s1_params.grad
 | 
			
		||||
            t_params -= lr * t_params.grad
 | 
			
		||||
            s2_params -= lr * s2_params.grad
 | 
			
		||||
 | 
			
		||||
    CONVENTIONS
 | 
			
		||||
    We adopt a right-hand coordinate system, meaning that rotation about an axis
 | 
			
		||||
    with a positive angle results in a counter clockwise rotation.
 | 
			
		||||
 | 
			
		||||
    This class assumes that transformations are applied on inputs which
 | 
			
		||||
    are row vectors. The internal representation of the Nx4x4 transformation
 | 
			
		||||
    matrix is of the form:
 | 
			
		||||
 | 
			
		||||
    .. code-block:: python
 | 
			
		||||
 | 
			
		||||
        M = [
 | 
			
		||||
                [Rxx, Ryx, Rzx, 0],
 | 
			
		||||
                [Rxy, Ryy, Rzy, 0],
 | 
			
		||||
                [Rxz, Ryz, Rzz, 0],
 | 
			
		||||
                [Tx,  Ty,  Tz,  1],
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
    To apply the transformation to points which are row vectors, the M matrix
 | 
			
		||||
    can be pre multiplied by the points:
 | 
			
		||||
 | 
			
		||||
    .. code-block:: python
 | 
			
		||||
 | 
			
		||||
        points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
 | 
			
		||||
        transformed_points = points * M
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dtype=torch.float32, device="cpu"):
 | 
			
		||||
        """
 | 
			
		||||
        This class assumes a row major ordering for all matrices.
 | 
			
		||||
        """
 | 
			
		||||
        self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
 | 
			
		||||
        self._transforms = []  # store transforms to compose
 | 
			
		||||
        self._lu = None
 | 
			
		||||
@ -493,9 +518,12 @@ class RotateAxisAngle(Rotate):
 | 
			
		||||
        Create a new Transform3d representing 3D rotation about an axis
 | 
			
		||||
        by an angle.
 | 
			
		||||
 | 
			
		||||
        Assuming a right-hand coordinate system, positive rotation angles result
 | 
			
		||||
        in a counter clockwise rotation.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            angle:
 | 
			
		||||
                - A torch tensor of shape (N, 1)
 | 
			
		||||
                - A torch tensor of shape (N,)
 | 
			
		||||
                - A python scalar
 | 
			
		||||
                - A torch scalar
 | 
			
		||||
            axis:
 | 
			
		||||
@ -509,21 +537,11 @@ class RotateAxisAngle(Rotate):
 | 
			
		||||
            raise ValueError(msg % axis)
 | 
			
		||||
        angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
 | 
			
		||||
        angle = (angle / 180.0 * math.pi) if degrees else angle
 | 
			
		||||
        N = angle.shape[0]
 | 
			
		||||
 | 
			
		||||
        cos = torch.cos(angle)
 | 
			
		||||
        sin = torch.sin(angle)
 | 
			
		||||
        one = torch.ones_like(angle)
 | 
			
		||||
        zero = torch.zeros_like(angle)
 | 
			
		||||
 | 
			
		||||
        if axis == "X":
 | 
			
		||||
            R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
 | 
			
		||||
        if axis == "Y":
 | 
			
		||||
            R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
 | 
			
		||||
        if axis == "Z":
 | 
			
		||||
            R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
 | 
			
		||||
 | 
			
		||||
        R = torch.stack(R_flat, -1).reshape((N, 3, 3))
 | 
			
		||||
        # We assume the points on which this transformation will be applied
 | 
			
		||||
        # are row vectors. The rotation matrix returned from _axis_angle_rotation
 | 
			
		||||
        # is for transforming column vectors. Therefore we transpose this matrix.
 | 
			
		||||
        # R will always be of shape (N, 3, 3)
 | 
			
		||||
        R = _axis_angle_rotation(axis, angle).transpose(1, 2)
 | 
			
		||||
        super().__init__(device=device, R=R)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -606,19 +624,16 @@ def _handle_input(
 | 
			
		||||
def _handle_angle_input(x, dtype, device: str, name: str):
 | 
			
		||||
    """
 | 
			
		||||
    Helper function for building a rotation function using angles.
 | 
			
		||||
    The output is always of shape (N, 1).
 | 
			
		||||
    The output is always of shape (N,).
 | 
			
		||||
 | 
			
		||||
    The input can be one of:
 | 
			
		||||
        - Torch tensor (N, 1) or (N)
 | 
			
		||||
        - Torch tensor of shape (N,)
 | 
			
		||||
        - Python scalar
 | 
			
		||||
        - Torch scalar
 | 
			
		||||
    """
 | 
			
		||||
    # If x is actually a tensor of shape (N, 1) then just return it
 | 
			
		||||
    if torch.is_tensor(x) and x.dim() == 2:
 | 
			
		||||
        if x.shape[1] != 1:
 | 
			
		||||
            msg = "Expected tensor of shape (N, 1); got %r (in %s)"
 | 
			
		||||
            raise ValueError(msg % (x.shape, name))
 | 
			
		||||
        return x
 | 
			
		||||
    if torch.is_tensor(x) and x.dim() > 1:
 | 
			
		||||
        msg = "Expected tensor of shape (N,); got %r (in %s)"
 | 
			
		||||
        raise ValueError(msg % (x.shape, name))
 | 
			
		||||
    else:
 | 
			
		||||
        return _handle_coord(x, dtype, device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@ import unittest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from pytorch3d.transforms.rotation_conversions import (
 | 
			
		||||
    _axis_angle_rotation,
 | 
			
		||||
    euler_angles_to_matrix,
 | 
			
		||||
    matrix_to_euler_angles,
 | 
			
		||||
    matrix_to_quaternion,
 | 
			
		||||
@ -118,7 +119,6 @@ class TestRotationConversion(unittest.TestCase):
 | 
			
		||||
    def test_to_euler(self):
 | 
			
		||||
        """mtx -> euler -> mtx"""
 | 
			
		||||
        data = random_rotations(13, dtype=torch.float64)
 | 
			
		||||
 | 
			
		||||
        for convention in self._all_euler_angle_conventions():
 | 
			
		||||
            euler_angles = matrix_to_euler_angles(data, convention)
 | 
			
		||||
            mdata = euler_angles_to_matrix(euler_angles, convention)
 | 
			
		||||
 | 
			
		||||
@ -120,7 +120,7 @@ class TestTransform(unittest.TestCase):
 | 
			
		||||
        self.assertTrue(torch.allclose(normals_out, normals_out_expected))
 | 
			
		||||
 | 
			
		||||
    def test_rotate_axis_angle(self):
 | 
			
		||||
        t = Transform3d().rotate_axis_angle(-90.0, axis="Z")
 | 
			
		||||
        t = Transform3d().rotate_axis_angle(90.0, axis="Z")
 | 
			
		||||
        points = torch.tensor(
 | 
			
		||||
            [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]
 | 
			
		||||
        ).view(1, 3, 3)
 | 
			
		||||
@ -737,15 +737,23 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [1.0, 0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0, -1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 1.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        points = torch.tensor([0.0, 1.0, 0.0])[None, None, :]  # (1, 1, 3)
 | 
			
		||||
        transformed_points = t.transform_points(points)
 | 
			
		||||
        expected_points = torch.tensor([0.0, 0.0, 1.0])
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(
 | 
			
		||||
                transformed_points.squeeze(), expected_points, atol=1e-7
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix))
 | 
			
		||||
 | 
			
		||||
    def test_rotate_x_torch_scalar(self):
 | 
			
		||||
@ -755,15 +763,23 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [1.0, 0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0, -1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 1.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        points = torch.tensor([0.0, 1.0, 0.0])[None, None, :]  # (1, 1, 3)
 | 
			
		||||
        transformed_points = t.transform_points(points)
 | 
			
		||||
        expected_points = torch.tensor([0.0, 0.0, 1.0])
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(
 | 
			
		||||
                transformed_points.squeeze(), expected_points, atol=1e-7
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 | 
			
		||||
 | 
			
		||||
    def test_rotate_x_torch_tensor(self):
 | 
			
		||||
@ -781,23 +797,23 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
                    [0.0, 0.0, 0.0, 1.0],
 | 
			
		||||
                ],
 | 
			
		||||
                [
 | 
			
		||||
                    [1.0,  0.0,   0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, r2_2, -r2_i, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, r2_i,  r2_2, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0,   0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,   0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  r2_2, r2_i, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, -r2_i, r2_2, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,   0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ],
 | 
			
		||||
                [
 | 
			
		||||
                    [1.0, 0.0,  0.0,  0.0],   # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0, -1.0,  0.0],   # noqa: E241, E201
 | 
			
		||||
                    [0.0, 1.0,  0.0,  0.0],   # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0,  0.0,  1.0],   # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0,  0.0],   # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0,  0.0],   # noqa: E241, E201
 | 
			
		||||
                    [0.0, -1.0, 0.0,  0.0],   # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0,  1.0],   # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 | 
			
		||||
        angle = angle[..., None]  # (N, 1)
 | 
			
		||||
        angle = angle
 | 
			
		||||
        t = RotateAxisAngle(angle=angle, axis="X")
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 | 
			
		||||
 | 
			
		||||
@ -807,33 +823,54 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0, -1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 1.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0, 0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        points = torch.tensor([1.0, 0.0, 0.0])[None, None, :]  # (1, 1, 3)
 | 
			
		||||
        transformed_points = t.transform_points(points)
 | 
			
		||||
        expected_points = torch.tensor([0.0, 0.0, -1.0])
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(
 | 
			
		||||
                transformed_points.squeeze(), expected_points, atol=1e-7
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 | 
			
		||||
 | 
			
		||||
    def test_rotate_y_torch_scalar(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test rotation about Y axis. With a right hand coordinate system this
 | 
			
		||||
        should result in a vector pointing along the x-axis being rotated to
 | 
			
		||||
        point along the negative z axis.
 | 
			
		||||
        """
 | 
			
		||||
        angle = torch.tensor(90.0)
 | 
			
		||||
        t = RotateAxisAngle(angle=angle, axis="Y")
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0, -1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 1.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0, 0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        points = torch.tensor([1.0, 0.0, 0.0])[None, None, :]  # (1, 1, 3)
 | 
			
		||||
        transformed_points = t.transform_points(points)
 | 
			
		||||
        expected_points = torch.tensor([0.0, 0.0, -1.0])
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(
 | 
			
		||||
                transformed_points.squeeze(), expected_points, atol=1e-7
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 | 
			
		||||
 | 
			
		||||
    def test_rotate_y_torch_tensor(self):
 | 
			
		||||
@ -851,16 +888,16 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
                    [0.0, 0.0, 0.0, 1.0],
 | 
			
		||||
                ],
 | 
			
		||||
                [
 | 
			
		||||
                    [ r2_2,  0.0,  r2_i, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [  0.0,  1.0,   0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-r2_i,  0.0,  r2_2, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [  0.0,  0.0,   0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [r2_2,  0.0, -r2_i, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0,  1.0,   0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [r2_i,  0.0,  r2_2, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0,  0.0,   0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ],
 | 
			
		||||
                [
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, -1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  1.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
@ -874,15 +911,23 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        points = torch.tensor([1.0, 0.0, 0.0])[None, None, :]  # (1, 1, 3)
 | 
			
		||||
        transformed_points = t.transform_points(points)
 | 
			
		||||
        expected_points = torch.tensor([0.0, 1.0, 0.0])
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(
 | 
			
		||||
                transformed_points.squeeze(), expected_points, atol=1e-7
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 | 
			
		||||
 | 
			
		||||
    def test_rotate_z_torch_scalar(self):
 | 
			
		||||
@ -892,15 +937,23 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        points = torch.tensor([1.0, 0.0, 0.0])[None, None, :]  # (1, 1, 3)
 | 
			
		||||
        transformed_points = t.transform_points(points)
 | 
			
		||||
        expected_points = torch.tensor([0.0, 1.0, 0.0])
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(
 | 
			
		||||
                transformed_points.squeeze(), expected_points, atol=1e-7
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 | 
			
		||||
 | 
			
		||||
    def test_rotate_z_torch_tensor(self):
 | 
			
		||||
@ -918,16 +971,16 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
                    [0.0, 0.0, 0.0, 1.0],
 | 
			
		||||
                ],
 | 
			
		||||
                [
 | 
			
		||||
                    [r2_2,  -r2_i,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [r2_i,   r2_2,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0,    0.0,  1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0,    0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [ r2_2,   r2_i,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-r2_i,   r2_2,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [  0.0,    0.0,  1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [  0.0,    0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ],
 | 
			
		||||
                [
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0,  1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
@ -945,10 +998,10 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix1 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [1.0, 0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0, -1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 1.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
@ -956,10 +1009,10 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix2 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0, -1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 1.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0, 0.0,  0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0, 0.0,  0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
@ -967,10 +1020,10 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix3 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
@ -987,10 +1040,10 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
@ -1004,10 +1057,10 @@ class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
        matrix = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [0.0, -1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [1.0,  0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [0.0,  0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 1.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [-1.0, 0.0, 0.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 1.0, 0.0],  # noqa: E241, E201
 | 
			
		||||
                    [ 0.0, 0.0, 0.0, 1.0],  # noqa: E241, E201
 | 
			
		||||
                ]
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user