mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Boolean indexing of cameras
Summary: Reasonable to expect bool indexing. Reviewed By: bottler, kjchalup Differential Revision: D38741446 fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
This commit is contained in:
parent
60808972b8
commit
b7c826b786
@ -385,31 +385,45 @@ class CamerasBase(TensorProperties):
|
||||
return self.image_size if hasattr(self, "image_size") else None
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], torch.LongTensor]
|
||||
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
|
||||
) -> "CamerasBase":
|
||||
"""
|
||||
Override for the __getitem__ method in TensorProperties which needs to be
|
||||
refactored.
|
||||
|
||||
Args:
|
||||
index: an int/list/long tensor used to index all the fields in the cameras given by
|
||||
self._FIELDS.
|
||||
index: an integer index, list/tensor of integer indices, or tensor of boolean
|
||||
indicators used to filter all the fields in the cameras given by self._FIELDS.
|
||||
Returns:
|
||||
if `index` is an index int/list/long tensor return an instance of the current
|
||||
cameras class with only the values at the selected index.
|
||||
an instance of the current cameras class with only the values at the selected index.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
|
||||
# pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
|
||||
if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)):
|
||||
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
|
||||
tensor_types = {
|
||||
"bool": (torch.BoolTensor, torch.cuda.BoolTensor),
|
||||
"long": (torch.LongTensor, torch.cuda.LongTensor),
|
||||
}
|
||||
if not isinstance(
|
||||
index, (int, list, *tensor_types["bool"], *tensor_types["long"])
|
||||
) or (
|
||||
isinstance(index, list)
|
||||
and not all(isinstance(i, int) and not isinstance(i, bool) for i in index)
|
||||
):
|
||||
msg = (
|
||||
"Invalid index type, expected int, List[int] or Bool/LongTensor; got %r"
|
||||
)
|
||||
raise ValueError(msg % type(index))
|
||||
|
||||
if isinstance(index, int):
|
||||
index = [index]
|
||||
|
||||
if max(index) >= len(self):
|
||||
if isinstance(index, tensor_types["bool"]):
|
||||
if index.ndim != 1 or index.shape[0] != len(self):
|
||||
raise ValueError(
|
||||
f"Boolean index of shape {index.shape} does not match cameras"
|
||||
)
|
||||
elif max(index) >= len(self):
|
||||
raise ValueError(f"Index {max(index)} is out of bounds for select cameras")
|
||||
|
||||
for field in self._FIELDS:
|
||||
|
@ -472,7 +472,9 @@ class Meshes:
|
||||
def __len__(self) -> int:
|
||||
return self._N
|
||||
|
||||
def __getitem__(self, index) -> "Meshes":
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
|
||||
) -> "Meshes":
|
||||
"""
|
||||
Args:
|
||||
index: Specifying the index of the mesh to retrieve.
|
||||
|
@ -360,7 +360,10 @@ class Pointclouds:
|
||||
def __len__(self) -> int:
|
||||
return self._N
|
||||
|
||||
def __getitem__(self, index) -> "Pointclouds":
|
||||
def __getitem__(
|
||||
self,
|
||||
index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor],
|
||||
) -> "Pointclouds":
|
||||
"""
|
||||
Args:
|
||||
index: Specifying the index of the cloud to retrieve.
|
||||
|
@ -501,7 +501,10 @@ class Volumes:
|
||||
return self._densities.shape[0]
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor]
|
||||
self,
|
||||
index: Union[
|
||||
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
|
||||
],
|
||||
) -> "Volumes":
|
||||
"""
|
||||
Args:
|
||||
|
@ -181,7 +181,7 @@ class Transform3d:
|
||||
return self.get_matrix().shape[0]
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], slice, torch.Tensor]
|
||||
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
|
||||
) -> "Transform3d":
|
||||
"""
|
||||
Args:
|
||||
|
@ -884,7 +884,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(new_cam.device == device)
|
||||
|
||||
def test_getitem(self):
|
||||
R_matrix = torch.randn((6, 3, 3))
|
||||
N_CAMERAS = 6
|
||||
R_matrix = torch.randn((N_CAMERAS, 3, 3))
|
||||
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
|
||||
|
||||
# Check get item returns an instance of the same class
|
||||
@ -908,22 +909,39 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(c012.R, R_matrix[0:3, ...])
|
||||
|
||||
# Check torch.LongTensor index
|
||||
index = torch.tensor([1, 3, 5], dtype=torch.int64)
|
||||
SLICE = [1, 3, 5]
|
||||
index = torch.tensor(SLICE, dtype=torch.int64)
|
||||
c135 = cam[index]
|
||||
self.assertEqual(len(c135), 3)
|
||||
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
|
||||
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
|
||||
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
|
||||
self.assertClose(c135.R, R_matrix[SLICE, ...])
|
||||
|
||||
# Check torch.BoolTensor index
|
||||
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
|
||||
index = torch.tensor(bool_slice, dtype=torch.bool)
|
||||
c135 = cam[index]
|
||||
self.assertEqual(len(c135), 3)
|
||||
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
|
||||
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
|
||||
self.assertClose(c135.R, R_matrix[SLICE, ...])
|
||||
|
||||
# Check errors with get item
|
||||
with self.assertRaisesRegex(ValueError, "out of bounds"):
|
||||
cam[6]
|
||||
cam[N_CAMERAS]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "does not match cameras"):
|
||||
index = torch.tensor([1, 0, 1], dtype=torch.bool)
|
||||
cam[index]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Invalid index type"):
|
||||
cam[slice(0, 1)]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Invalid index type"):
|
||||
index = torch.tensor([1, 3, 5], dtype=torch.float32)
|
||||
cam[[True, False]]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Invalid index type"):
|
||||
index = torch.tensor(SLICE, dtype=torch.float32)
|
||||
cam[index]
|
||||
|
||||
def test_get_full_transform(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user