From c8f3d6bc0bc366d44c8fca790c5e433503c7785f Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein
Date: Tue, 15 Feb 2022 06:46:38 -0800
Subject: [PATCH] Fix Transform3d.stack of compositions

Summary:
Add a test for Transform3d.stack, and make it work with composed transformations.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1072 .

Reviewed By: patricklabatut

Differential Revision: D34211920

fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
---
 pytorch3d/renderer/cameras.py       |  2 +-
 pytorch3d/transforms/transform3d.py | 62 +++++++++++++++++------------
 tests/test_transforms.py            | 26 ++++++++++++
 3 files changed, 64 insertions(+), 26 deletions(-)

diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py
index a840d54b..2dcc9110 100644
--- a/pytorch3d/renderer/cameras.py
+++ b/pytorch3d/renderer/cameras.py
@@ -1649,7 +1649,7 @@ def look_at_view_transform(
     elev=0.0,
     azim=0.0,
     degrees: bool = True,
-    eye: Optional[Sequence] = None,
+    eye: Optional[Union[Sequence, torch.Tensor]] = None,
     at=((0, 0, 0),),  # (1, 3)
     up=((0, 1, 0),),  # (1, 3)
     device: Device = "cpu",
diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py
index db30ec48..a3dea212 100644
--- a/pytorch3d/transforms/transform3d.py
+++ b/pytorch3d/transforms/transform3d.py
@@ -196,10 +196,10 @@ class Transform3d:
             index = [index]
         return self.__class__(matrix=self.get_matrix()[index])
 
-    def compose(self, *others):
+    def compose(self, *others: "Transform3d") -> "Transform3d":
         """
-        Return a new Transform3d with the transforms to compose stored as
-        an internal list.
+        Return a new Transform3d representing the composition of self with the
+        given other transforms, which will be stored as an internal list.
 
         Args:
             *others: Any number of Transform3d objects
@@ -216,7 +216,7 @@ class Transform3d:
         out._transforms = self._transforms + list(others)
         return out
 
-    def get_matrix(self):
+    def get_matrix(self) -> torch.Tensor:
         """
         Return a matrix which is the result of composing this transform
         with others stored in self.transforms. Where necessary transforms
@@ -240,13 +240,13 @@ class Transform3d:
             composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
         return composed_matrix
 
-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
         return torch.inverse(self._matrix)
 
-    def inverse(self, invert_composed: bool = False):
+    def inverse(self, invert_composed: bool = False) -> "Transform3d":
         """
         Returns a new Transform3d object that represents an inverse of the
         current transformation.
@@ -295,14 +295,24 @@ class Transform3d:
 
         return tinv
 
-    def stack(self, *others):
+    def stack(self, *others: "Transform3d") -> "Transform3d":
+        """
+        Return a new batched Transform3d representing the batch elements from
+        self and all the given other transforms all batched together.
+
+        Args:
+            *others: Any number of Transform3d objects
+
+        Returns:
+            A new Transform3d.
+        """
         transforms = [self] + list(others)
-        matrix = torch.cat([t._matrix for t in transforms], dim=0)
+        matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
         out = Transform3d(dtype=self.dtype, device=self.device)
         out._matrix = matrix
         return out
 
-    def transform_points(self, points, eps: Optional[float] = None):
+    def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor:
         """
         Use this transform to transform a set of 3D points. Assumes row major
         ordering of the input points.
@@ -347,7 +357,7 @@ class Transform3d:
 
         return points_out
 
-    def transform_normals(self, normals):
+    def transform_normals(self, normals) -> torch.Tensor:
         """
         Use this transform to transform a set of normal vectors.
 
@@ -379,19 +389,19 @@ class Transform3d:
 
         return normals_out
 
-    def translate(self, *args, **kwargs):
+    def translate(self, *args, **kwargs) -> "Transform3d":
         return self.compose(Translate(device=self.device, *args, **kwargs))
 
-    def scale(self, *args, **kwargs):
+    def scale(self, *args, **kwargs) -> "Transform3d":
         return self.compose(Scale(device=self.device, *args, **kwargs))
 
-    def rotate(self, *args, **kwargs):
+    def rotate(self, *args, **kwargs) -> "Transform3d":
         return self.compose(Rotate(device=self.device, *args, **kwargs))
 
-    def rotate_axis_angle(self, *args, **kwargs):
+    def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
         return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
 
-    def clone(self):
+    def clone(self) -> "Transform3d":
         """
         Deep copy of Transforms object. All internal tensors are cloned
         individually.
@@ -411,7 +421,7 @@ class Transform3d:
         device: Device,
         copy: bool = False,
         dtype: Optional[torch.dtype] = None,
-    ):
+    ) -> "Transform3d":
         """
         Match functionality of torch.Tensor.to()
         If copy = True or the self Tensor is on a different device, the
@@ -448,10 +458,10 @@ class Transform3d:
         ]
         return other
 
-    def cpu(self):
+    def cpu(self) -> "Transform3d":
         return self.to("cpu")
 
-    def cuda(self):
+    def cuda(self) -> "Transform3d":
         return self.to("cuda")
 
 
@@ -486,7 +496,7 @@ class Translate(Transform3d):
         mat[:, 3, :3] = xyz
         self._matrix = mat
 
-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
@@ -533,7 +543,7 @@ class Scale(Transform3d):
         mat[:, 2, 2] = xyz[:, 2]
         self._matrix = mat
 
-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
@@ -575,7 +585,7 @@ class Rotate(Transform3d):
         mat[:, :3, :3] = R
         self._matrix = mat
 
-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
@@ -622,7 +632,7 @@ class RotateAxisAngle(Rotate):
         super().__init__(device=angle.device, R=R)
 
 
-def _handle_coord(c, dtype: torch.dtype, device: torch.device):
+def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
     """
     Helper function for _handle_input.
 
@@ -649,7 +659,7 @@ def _handle_input(
     device: Optional[Device],
     name: str,
     allow_singleton: bool = False,
-):
+) -> torch.Tensor:
     """
     Helper function to handle parsing logic for building transforms. The output
     is always a tensor of shape (N, 3), but there are several types of allowed
@@ -707,7 +717,9 @@ def _handle_input(
     return xyz
 
 
-def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
+def _handle_angle_input(
+    x, dtype: torch.dtype, device: Optional[Device], name: str
+) -> torch.Tensor:
     """
     Helper function for building a rotation function using angles.
     The output is always of shape (N,).
@@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: s
     return _handle_coord(x, dtype, device_)
 
 
-def _broadcast_bmm(a, b):
+def _broadcast_bmm(a, b) -> torch.Tensor:
     """
     Batch multiply two matrices and broadcast if necessary.
 
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index a819f5b3..f4690a41 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -10,6 +10,7 @@ import unittest
 
 import torch
 from common_testing import TestCaseMixin
+from pytorch3d.transforms import random_rotations
 from pytorch3d.transforms.so3 import so3_exp_map
 from pytorch3d.transforms.transform3d import (
     Rotate,
@@ -21,6 +22,9 @@
 
 
 class TestTransform(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        torch.manual_seed(42)
+
     def test_to(self):
         tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
         R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
@@ -406,6 +410,28 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
         with self.assertRaises(IndexError):
             t3d_selected = t3d[invalid_index]
 
+    def test_stack(self):
+        rotations = random_rotations(3)
+        transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
+        transform1 = Scale(37)
+        transform4 = transform1.stack(transform3)
+        self.assertEqual(len(transform1), 1)
+        self.assertEqual(len(transform3), 3)
+        self.assertEqual(len(transform4), 4)
+        self.assertClose(
+            transform4.get_matrix(),
+            torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
+        )
+        points = torch.rand(4, 5, 3)
+        new_points_expect = torch.cat(
+            [
+                transform1.transform_points(points[:1]),
+                transform3.transform_points(points[1:]),
+            ]
+        )
+        new_points = transform4.transform_points(points)
+        self.assertClose(new_points, new_points_expect)
+
 
 class TestTranslate(unittest.TestCase):
     def test_python_scalar(self):
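
Note (not part of the patch): a Transform3d built via compose()/rotate()/translate() keeps the pending transforms in its internal `_transforms` list and only folds them into a single matrix inside get_matrix(); the old stack() read the raw `_matrix` attribute and therefore dropped the composed part. Below is a minimal usage sketch of the fixed behaviour, assuming a PyTorch3D build that includes this commit; variable names are illustrative only.

    import torch
    from pytorch3d.transforms import random_rotations
    from pytorch3d.transforms.transform3d import Scale, Transform3d

    # A composed transform: its rotation/translation live in _transforms until
    # get_matrix() is called, which is why stack() must go through get_matrix().
    composed = Transform3d().rotate(random_rotations(3)).translate(0.0, 0.0, 1.0)
    single = Scale(2.0)

    stacked = single.stack(composed)  # batch sizes 1 + 3 -> 4

    points = torch.rand(4, 5, 3)
    expected = torch.cat(
        [single.transform_points(points[:1]), composed.transform_points(points[1:])]
    )
    # With the fix, the stacked transform matches the per-part results; before it,
    # the composed rotation/translation was silently ignored.
    assert torch.allclose(stacked.transform_points(points), expected, atol=1e-5)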