Boolean indexing of cameras

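Summary: It is reasonable to expect cameras to support boolean-mask indexing. `CamerasBase.__getitem__` now accepts a 1-D `torch.BoolTensor` whose length equals the number of cameras, and rejects lists containing non-int (including bool) elements. The `__getitem__` type hints of `Meshes`, `Pointclouds`, `Volumes` and `Transform3d` are updated to list `torch.BoolTensor` and `torch.LongTensor` explicitly.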

Reviewed By: bottler, kjchalup

Differential Revision: D38741446

fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
This commit is contained in:
Roman Shapovalov 2022-08-16 15:19:39 -07:00 committed by Facebook GitHub Bot
parent 60808972b8
commit b7c826b786
6 changed files with 58 additions and 18 deletions

View File

@ -385,31 +385,45 @@ class CamerasBase(TensorProperties):
return self.image_size if hasattr(self, "image_size") else None
def __getitem__(
self, index: Union[int, List[int], torch.LongTensor]
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
) -> "CamerasBase":
"""
Override for the __getitem__ method in TensorProperties which needs to be
refactored.
Args:
index: an int/list/long tensor used to index all the fields in the cameras given by
self._FIELDS.
index: an integer index, list/tensor of integer indices, or tensor of boolean
indicators used to filter all the fields in the cameras given by self._FIELDS.
Returns:
if `index` is an index int/list/long tensor return an instance of the current
cameras class with only the values at the selected index.
an instance of the current cameras class with only the values at the selected index.
"""
kwargs = {}
# pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)):
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
tensor_types = {
"bool": (torch.BoolTensor, torch.cuda.BoolTensor),
"long": (torch.LongTensor, torch.cuda.LongTensor),
}
if not isinstance(
index, (int, list, *tensor_types["bool"], *tensor_types["long"])
) or (
isinstance(index, list)
and not all(isinstance(i, int) and not isinstance(i, bool) for i in index)
):
msg = (
"Invalid index type, expected int, List[int] or Bool/LongTensor; got %r"
)
raise ValueError(msg % type(index))
if isinstance(index, int):
index = [index]
if max(index) >= len(self):
if isinstance(index, tensor_types["bool"]):
if index.ndim != 1 or index.shape[0] != len(self):
raise ValueError(
f"Boolean index of shape {index.shape} does not match cameras"
)
elif max(index) >= len(self):
raise ValueError(f"Index {max(index)} is out of bounds for select cameras")
for field in self._FIELDS:

View File

@ -472,7 +472,9 @@ class Meshes:
def __len__(self) -> int:
return self._N
def __getitem__(self, index) -> "Meshes":
def __getitem__(
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Meshes":
"""
Args:
index: Specifying the index of the mesh to retrieve.

View File

@ -360,7 +360,10 @@ class Pointclouds:
def __len__(self) -> int:
return self._N
def __getitem__(self, index) -> "Pointclouds":
def __getitem__(
self,
index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor],
) -> "Pointclouds":
"""
Args:
index: Specifying the index of the cloud to retrieve.

View File

@ -501,7 +501,10 @@ class Volumes:
return self._densities.shape[0]
def __getitem__(
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor]
self,
index: Union[
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
],
) -> "Volumes":
"""
Args:

View File

@ -181,7 +181,7 @@ class Transform3d:
return self.get_matrix().shape[0]
def __getitem__(
self, index: Union[int, List[int], slice, torch.Tensor]
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
) -> "Transform3d":
"""
Args:

View File

@ -884,7 +884,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertTrue(new_cam.device == device)
def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
N_CAMERAS = 6
R_matrix = torch.randn((N_CAMERAS, 3, 3))
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
# Check get item returns an instance of the same class
@ -908,22 +909,39 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(c012.R, R_matrix[0:3, ...])
# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
SLICE = [1, 3, 5]
index = torch.tensor(SLICE, dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check torch.BoolTensor index
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
index = torch.tensor(bool_slice, dtype=torch.bool)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check errors with get item
with self.assertRaisesRegex(ValueError, "out of bounds"):
cam[6]
cam[N_CAMERAS]
with self.assertRaisesRegex(ValueError, "does not match cameras"):
index = torch.tensor([1, 0, 1], dtype=torch.bool)
cam[index]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[slice(0, 1)]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor([1, 3, 5], dtype=torch.float32)
cam[[True, False]]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor(SLICE, dtype=torch.float32)
cam[index]
def test_get_full_transform(self):