__getitem__ for Transform3D

Summary: Implements the `__getitem__` method for `Transform3D`

Reviewed By: nikhilaravi

Differential Revision: D23813975

fbshipit-source-id: 5da752ed8ea029ad0af58bb7a7856f0995519b7a
This commit is contained in:
David Novotny
2021-01-05 03:37:38 -08:00
committed by Facebook GitHub Bot
parent ac3f8dc833
commit 1e4a2e8624
2 changed files with 135 additions and 2 deletions

View File

@@ -2,7 +2,7 @@
import math
import warnings
from typing import Optional
from typing import List, Optional, Union
import torch
@@ -172,6 +172,22 @@ class Transform3d:
def __len__(self):
return self.get_matrix().shape[0]
def __getitem__(
self, index: Union[int, List[int], slice, torch.Tensor]
) -> "Transform3d":
"""
Args:
index: Specifying the index of the transform to retrieve.
Can be an int, slice, list of ints, boolean, long tensor.
Supports negative indices.
Returns:
Transform3d object with selected transforms. The tensors are not cloned.
"""
if isinstance(index, int):
index = [index]
return self.__class__(matrix=self.get_matrix()[index])
def compose(self, *others):
"""
Return a new Transform3d with the tranforms to compose stored as
@@ -361,6 +377,9 @@ class Transform3d:
def scale(self, *args, **kwargs):
return self.compose(Scale(device=self.device, *args, **kwargs))
def rotate(self, *args, **kwargs):
return self.compose(Rotate(device=self.device, *args, **kwargs))
def rotate_axis_angle(self, *args, **kwargs):
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))