diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 0d2e330c..03ae253a 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -2,7 +2,7 @@ import math import warnings -from typing import Optional +from typing import List, Optional, Union import torch @@ -172,6 +172,22 @@ class Transform3d: def __len__(self): return self.get_matrix().shape[0] + def __getitem__( + self, index: Union[int, List[int], slice, torch.Tensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(matrix=self.get_matrix()[index]) + def compose(self, *others): """ Return a new Transform3d with the tranforms to compose stored as @@ -361,6 +377,9 @@ class Transform3d: def scale(self, *args, **kwargs): return self.compose(Scale(device=self.device, *args, **kwargs)) + def rotate(self, *args, **kwargs): + return self.compose(Rotate(device=self.device, *args, **kwargs)) + def rotate_axis_angle(self, *args, **kwargs): return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs)) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b2acdce4..62404d7a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,6 +5,7 @@ import math import unittest import torch +from common_testing import TestCaseMixin from pytorch3d.transforms.so3 import so3_exponential_map from pytorch3d.transforms.transform3d import ( Rotate, @@ -15,7 +16,7 @@ from pytorch3d.transforms.transform3d import ( ) -class TestTransform(unittest.TestCase): +class TestTransform(TestCaseMixin, unittest.TestCase): def test_to(self): tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]])) R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) @@ -89,6 +90,22 @@ class TestTransform(unittest.TestCase): self.assertTrue(torch.allclose(points_out, points_out_expected)) self.assertTrue(torch.allclose(normals_out, normals_out_expected)) + def test_rotate(self): + R = so3_exponential_map(torch.randn((1, 3))) + t = Transform3d().rotate(R) + points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( + 1, 3, 3 + ) + normals = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] + ).view(1, 3, 3) + points_out = t.transform_points(points) + normals_out = t.transform_normals(normals) + points_out_expected = torch.bmm(points, R) + normals_out_expected = torch.bmm(normals, R) + self.assertTrue(torch.allclose(points_out, points_out_expected)) + self.assertTrue(torch.allclose(normals_out, normals_out_expected)) + def test_scale(self): t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0) points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( @@ -237,6 +254,103 @@ class TestTransform(unittest.TestCase): for m in (m1, m2, m3, m4): self.assertTrue(torch.allclose(m, m5, atol=1e-3)) + def _check_indexed_transforms(self, t3d, t3d_selected, indices): + t3d_matrix = t3d.get_matrix() + t3d_selected_matrix = t3d_selected.get_matrix() + for order_index, selected_index in indices: + self.assertClose( + t3d_matrix[selected_index], t3d_selected_matrix[order_index] + ) + + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + + matrices = torch.randn( + size=[batch_size, 4, 4], device=device, dtype=torch.float32 + ) + + # init the Transforms3D class + t3d = Transform3d(matrix=matrices) + + # int index + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self._check_indexed_transforms(t3d, t3d_selected, [(0, 1)]) + + # negative int index + index = -1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self._check_indexed_transforms(t3d, t3d_selected, [(0, -1)]) + + # list index + index = [1, 2] + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), len(index)) + self._check_indexed_transforms(t3d, t3d_selected, enumerate(index)) + + # empty list index + index = [] + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 0) + self.assertEqual(t3d_selected.get_matrix().nelement(), 0) + + # slice index + index = slice(0, 2, 1) + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 2) + self._check_indexed_transforms(t3d, t3d_selected, [(0, 0), (1, 1)]) + + # empty slice index + index = slice(0, 0, 1) + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 0) + self.assertEqual(t3d_selected.get_matrix().nelement(), 0) + + # bool tensor + index = (torch.rand(batch_size) > 0.5).to(device) + index[:2] = True # make sure smth is selected + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), index.sum()) + self._check_indexed_transforms( + t3d, + t3d_selected, + zip( + torch.arange(index.sum()), + torch.nonzero(index, as_tuple=False).squeeze(), + ), + ) + + # all false bool tensor + index = torch.zeros(batch_size).bool() + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 0) + self.assertEqual(t3d_selected.get_matrix().nelement(), 0) + + # int tensor + index = torch.tensor([1, 2], dtype=torch.int64, device=device) + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), index.numel()) + self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist())) + + # negative int tensor + index = -(torch.tensor([1, 2], dtype=torch.int64, device=device)) + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), index.numel()) + self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist())) + + # invalid index + for invalid_index in ( + torch.tensor([1, 0, 1], dtype=torch.float32, device=device), # float tensor + 1.2, # float index + torch.tensor( + [[1, 0, 1], [1, 0, 1]], dtype=torch.int32, device=device + ), # multidimensional tensor + ): + with self.assertRaises(IndexError): + t3d_selected = t3d[invalid_index] + class TestTranslate(unittest.TestCase): def test_python_scalar(self):