mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-30 02:28:56 +08:00
Boolean indexing of cameras
Summary: Reasonable to expect bool indexing. Reviewed By: bottler, kjchalup Differential Revision: D38741446 fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
60808972b8
commit
b7c826b786
@@ -385,31 +385,45 @@ class CamerasBase(TensorProperties):
|
||||
return self.image_size if hasattr(self, "image_size") else None
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], torch.LongTensor]
|
||||
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
|
||||
) -> "CamerasBase":
|
||||
"""
|
||||
Override for the __getitem__ method in TensorProperties which needs to be
|
||||
refactored.
|
||||
|
||||
Args:
|
||||
index: an int/list/long tensor used to index all the fields in the cameras given by
|
||||
self._FIELDS.
|
||||
index: an integer index, list/tensor of integer indices, or tensor of boolean
|
||||
indicators used to filter all the fields in the cameras given by self._FIELDS.
|
||||
Returns:
|
||||
if `index` is an index int/list/long tensor return an instance of the current
|
||||
cameras class with only the values at the selected index.
|
||||
an instance of the current cameras class with only the values at the selected index.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
|
||||
# pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
|
||||
if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)):
|
||||
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
|
||||
tensor_types = {
|
||||
"bool": (torch.BoolTensor, torch.cuda.BoolTensor),
|
||||
"long": (torch.LongTensor, torch.cuda.LongTensor),
|
||||
}
|
||||
if not isinstance(
|
||||
index, (int, list, *tensor_types["bool"], *tensor_types["long"])
|
||||
) or (
|
||||
isinstance(index, list)
|
||||
and not all(isinstance(i, int) and not isinstance(i, bool) for i in index)
|
||||
):
|
||||
msg = (
|
||||
"Invalid index type, expected int, List[int] or Bool/LongTensor; got %r"
|
||||
)
|
||||
raise ValueError(msg % type(index))
|
||||
|
||||
if isinstance(index, int):
|
||||
index = [index]
|
||||
|
||||
if max(index) >= len(self):
|
||||
if isinstance(index, tensor_types["bool"]):
|
||||
if index.ndim != 1 or index.shape[0] != len(self):
|
||||
raise ValueError(
|
||||
f"Boolean index of shape {index.shape} does not match cameras"
|
||||
)
|
||||
elif max(index) >= len(self):
|
||||
raise ValueError(f"Index {max(index)} is out of bounds for select cameras")
|
||||
|
||||
for field in self._FIELDS:
|
||||
|
||||
@@ -472,7 +472,9 @@ class Meshes:
|
||||
def __len__(self) -> int:
|
||||
return self._N
|
||||
|
||||
def __getitem__(self, index) -> "Meshes":
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
|
||||
) -> "Meshes":
|
||||
"""
|
||||
Args:
|
||||
index: Specifying the index of the mesh to retrieve.
|
||||
|
||||
@@ -360,7 +360,10 @@ class Pointclouds:
|
||||
def __len__(self) -> int:
|
||||
return self._N
|
||||
|
||||
def __getitem__(self, index) -> "Pointclouds":
|
||||
def __getitem__(
|
||||
self,
|
||||
index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor],
|
||||
) -> "Pointclouds":
|
||||
"""
|
||||
Args:
|
||||
index: Specifying the index of the cloud to retrieve.
|
||||
|
||||
@@ -501,7 +501,10 @@ class Volumes:
|
||||
return self._densities.shape[0]
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor]
|
||||
self,
|
||||
index: Union[
|
||||
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
|
||||
],
|
||||
) -> "Volumes":
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -181,7 +181,7 @@ class Transform3d:
|
||||
return self.get_matrix().shape[0]
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], slice, torch.Tensor]
|
||||
self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]
|
||||
) -> "Transform3d":
|
||||
"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user