From 90dc7a08568072375fe9f7ecc3201618fba86287 Mon Sep 17 00:00:00 2001 From: David Novotny Date: Sun, 5 Apr 2020 14:42:22 -0700 Subject: [PATCH] Initialization of Transform3D with a custom matrix. Summary: Allows to initialize a Transform3D object with a batch of user-defined transformation matrices: ``` t = Transform3D(matrix=torch.randn(2, 4, 4)) ``` Reviewed By: nikhilaravi Differential Revision: D20693475 fbshipit-source-id: dccc49b2ca4c19a034844c63463953ba8f52c1bc --- pytorch3d/transforms/transform3d.py | 33 +++++++++++++++++++++++++++-- tests/test_transforms.py | 13 ++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index e3f58050..a5fdeeff 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -134,8 +134,37 @@ class Transform3d: """ - def __init__(self, dtype=torch.float32, device="cpu"): - self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4) + def __init__( + self, + dtype: torch.dtype = torch.float32, + device="cpu", + matrix: Optional[torch.Tensor] = None, + ): + """ + Args: + dtype: The data type of the transformation matrix. + to be used if `matrix = None`. + device: The device for storing the implemented transformation. + If `matrix != None`, uses the device of input `matrix`. + matrix: A tensor of shape (4, 4) or of shape (minibatch, 4, 4) + representing the 4x4 3D transformation matrix. + If `None`, initializes with identity using + the specified `device` and `dtype`. + """ + + if matrix is None: + self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4) + else: + if matrix.ndim not in (2, 3): + raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.') + if matrix.shape[-2] != 4 or matrix.shape[-1] != 4: + raise ValueError( + '"matrix" has to be a tensor of shape (minibatch, 4, 4)' + ) + # set the device from matrix + device = matrix.device + self._matrix = matrix.view(-1, 4, 4) + self._transforms = [] # store transforms to compose self._lu = None self.device = device diff --git a/tests/test_transforms.py b/tests/test_transforms.py index d466937b..b2acdce4 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -57,6 +57,19 @@ class TestTransform(unittest.TestCase): matrix2 = t_pair[1].get_matrix() self.assertTrue(torch.allclose(matrix1, matrix2)) + def test_init_with_custom_matrix(self): + for matrix in (torch.randn(10, 4, 4), torch.randn(4, 4)): + t = Transform3d(matrix=matrix) + self.assertTrue(t.device == matrix.device) + self.assertTrue(t._matrix.dtype == matrix.dtype) + self.assertTrue(torch.allclose(t._matrix, matrix.view(t._matrix.shape))) + + def test_init_with_custom_matrix_errors(self): + bad_shapes = [[10, 5, 4], [3, 4], [10, 4, 4, 1], [10, 4, 4, 2], [4, 4, 4, 3]] + for bad_shape in bad_shapes: + matrix = torch.randn(*bad_shape).float() + self.assertRaises(ValueError, Transform3d, matrix=matrix) + def test_translate(self): t = Transform3d().translate(1, 2, 3) points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(