mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Allow indexing for classes inheriting Transform3d (#1801)
Summary: Currently, it is not possible to access a sub-transform using an indexer for all 3d transforms inheriting the `Transforms3d` class. For instance: ```python from pytorch3d import transforms N = 10 r = transforms.random_rotations(N) T = transforms.Transform3d().rotate(R=r) R = transforms.Rotate(r) x = T[0] # ok x = R[0] # TypeError: __init__() got an unexpected keyword argument 'matrix' ``` This is because all these classes (namely `Rotate`, `Translate`, `Scale`, `RotateAxisAngle`) inherit the `__getitem__()` method from `Transform3d` which has the [following code on line 201](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/transform3d.py#L201): ```python return self.__class__(matrix=self.get_matrix()[index]) ``` The four classes inheriting `Transform3d` are not initialized through a matrix argument, hence they error. I propose to modify the `__getitem__()` method of the `Transform3d` class to fix this behavior. The least invasive way to do it I can think of consists of creating an empty instance of the current class, then setting the `_matrix` attribute manually. Thus, instead of ```python return self.__class__(matrix=self.get_matrix()[index]) ``` I propose to do: ```python instance = self.__class__.__new__(self.__class__) instance._matrix = self.get_matrix()[index] return instance ``` As far as I can tell, this modification occurs no modification whatsoever for the user, except for the ability to index all 3d transforms. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1801 Reviewed By: MichaelRamamonjisoa Differential Revision: D58410389 Pulled By: bottler fbshipit-source-id: f371e4c63d2ae4c927a7ad48c2de8862761078de
This commit is contained in:
		
							parent
							
								
									b66d17a324
								
							
						
					
					
						commit
						b0462d8079
					
				@ -564,6 +564,22 @@ class Translate(Transform3d):
 | 
			
		||||
        i_matrix = self._matrix * inv_mask
 | 
			
		||||
        return i_matrix
 | 
			
		||||
 | 
			
		||||
    def __getitem__(
 | 
			
		||||
        self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
 | 
			
		||||
    ) -> "Transform3d":
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: Specifying the index of the transform to retrieve.
 | 
			
		||||
                Can be an int, slice, list of ints, boolean, long tensor.
 | 
			
		||||
                Supports negative indices.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Transform3d object with selected transforms. The tensors are not cloned.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            index = [index]
 | 
			
		||||
        return self.__class__(self.get_matrix()[index, 3, :3])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Scale(Transform3d):
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -613,6 +629,26 @@ class Scale(Transform3d):
 | 
			
		||||
        imat = torch.diag_embed(ixyz, dim1=1, dim2=2)
 | 
			
		||||
        return imat
 | 
			
		||||
 | 
			
		||||
    def __getitem__(
 | 
			
		||||
        self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
 | 
			
		||||
    ) -> "Transform3d":
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: Specifying the index of the transform to retrieve.
 | 
			
		||||
                Can be an int, slice, list of ints, boolean, long tensor.
 | 
			
		||||
                Supports negative indices.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Transform3d object with selected transforms. The tensors are not cloned.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            index = [index]
 | 
			
		||||
        mat = self.get_matrix()[index]
 | 
			
		||||
        x = mat[:, 0, 0]
 | 
			
		||||
        y = mat[:, 1, 1]
 | 
			
		||||
        z = mat[:, 2, 2]
 | 
			
		||||
        return self.__class__(x, y, z)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Rotate(Transform3d):
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -655,6 +691,22 @@ class Rotate(Transform3d):
 | 
			
		||||
        """
 | 
			
		||||
        return self._matrix.permute(0, 2, 1).contiguous()
 | 
			
		||||
 | 
			
		||||
    def __getitem__(
 | 
			
		||||
        self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
 | 
			
		||||
    ) -> "Transform3d":
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: Specifying the index of the transform to retrieve.
 | 
			
		||||
                Can be an int, slice, list of ints, boolean, long tensor.
 | 
			
		||||
                Supports negative indices.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Transform3d object with selected transforms. The tensors are not cloned.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            index = [index]
 | 
			
		||||
        return self.__class__(self.get_matrix()[index, :3, :3])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RotateAxisAngle(Rotate):
 | 
			
		||||
    def __init__(
 | 
			
		||||
 | 
			
		||||
@ -685,6 +685,15 @@ class TestTranslate(unittest.TestCase):
 | 
			
		||||
        self.assertTrue(torch.allclose(im, im_comp))
 | 
			
		||||
        self.assertTrue(torch.allclose(im, im_2))
 | 
			
		||||
 | 
			
		||||
    def test_get_item(self, batch_size=5):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        xyz = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32)
 | 
			
		||||
        t3d = Translate(xyz)
 | 
			
		||||
        index = 1
 | 
			
		||||
        t3d_selected = t3d[index]
 | 
			
		||||
        self.assertEqual(len(t3d_selected), 1)
 | 
			
		||||
        self.assertIsInstance(t3d_selected, Translate)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestScale(unittest.TestCase):
 | 
			
		||||
    def test_single_python_scalar(self):
 | 
			
		||||
@ -871,6 +880,15 @@ class TestScale(unittest.TestCase):
 | 
			
		||||
        self.assertTrue(torch.allclose(im, im_comp))
 | 
			
		||||
        self.assertTrue(torch.allclose(im, im_2))
 | 
			
		||||
 | 
			
		||||
    def test_get_item(self, batch_size=5):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        s = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32)
 | 
			
		||||
        t3d = Scale(s)
 | 
			
		||||
        index = 1
 | 
			
		||||
        t3d_selected = t3d[index]
 | 
			
		||||
        self.assertEqual(len(t3d_selected), 1)
 | 
			
		||||
        self.assertIsInstance(t3d_selected, Scale)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTransformBroadcast(unittest.TestCase):
 | 
			
		||||
    def test_broadcast_transform_points(self):
 | 
			
		||||
@ -986,6 +1004,15 @@ class TestRotate(unittest.TestCase):
 | 
			
		||||
        self.assertTrue(torch.allclose(im, im_comp, atol=1e-4))
 | 
			
		||||
        self.assertTrue(torch.allclose(im, im_2, atol=1e-4))
 | 
			
		||||
 | 
			
		||||
    def test_get_item(self, batch_size=5):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        r = random_rotations(batch_size, dtype=torch.float32, device=device)
 | 
			
		||||
        t3d = Rotate(r)
 | 
			
		||||
        index = 1
 | 
			
		||||
        t3d_selected = t3d[index]
 | 
			
		||||
        self.assertEqual(len(t3d_selected), 1)
 | 
			
		||||
        self.assertIsInstance(t3d_selected, Rotate)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestRotateAxisAngle(unittest.TestCase):
 | 
			
		||||
    def test_rotate_x_python_scalar(self):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user