Enable __getitem__ for Cameras to return an instance of Cameras

Summary:
Added a custom `__getitem__` method to `CamerasBase` which returns an instance of the appropriate camera instead of the `TensorAccessor` class.

Long term, we should deprecate the `TensorAccessor` class and the `__getitem__` method on `TensorProperties`.

FB: In the next diff I will update the uses of `select_cameras` in implicitron.

Reviewed By: bottler

Differential Revision: D33185885

fbshipit-source-id: c31995d0eb126981e91ba61a6151d5404b263f67
This commit is contained in:
Nikhila Ravi
2021-12-21 05:45:32 -08:00
committed by Facebook GitHub Bot
parent cc3259ba93
commit 28ccdb7328
3 changed files with 224 additions and 13 deletions

View File

@@ -75,6 +75,10 @@ class CamerasBase(TensorProperties):
boolean argument of the function.
"""
# Names of the attributes that __getitem__ indexes and then passes back to the
# constructor as keyword arguments when building the selected cameras.
# The base class leaves this empty; each concrete camera class is expected to
# override it with a tuple of its own attribute names.
_FIELDS: Tuple = ()
def get_projection_transform(self):
"""
Calculate the projective transformation matrix.
@@ -362,6 +366,55 @@ class CamerasBase(TensorProperties):
"""
return self.image_size if hasattr(self, "image_size") else None
def __getitem__(
    self, index: Union[int, List[int], torch.LongTensor]
) -> "CamerasBase":
    """
    Select a subset of the cameras in the batch and return it as a new
    instance of the same camera class.

    Override for the __getitem__ method in TensorProperties which needs to be
    refactored.

    Args:
        index: an int, a non-empty list of ints, or an int64 (long) tensor used
            to index all the tensor fields in the cameras given by
            self._FIELDS. Negative values count from the end of the batch.

    Returns:
        A new instance of the current cameras class holding only the values at
        the selected index. Tensor fields are indexed; str/bool fields are
        passed through unchanged; fields whose value is None are omitted.

    Raises:
        ValueError: if `index` has an unsupported type, is empty, or refers to
            a position outside [-len(self), len(self)), or if a field listed
            in self._FIELDS holds a value that is not a str, bool or tensor.
    """
    # torch.LongTensor only matches CPU tensors; checking the dtype instead
    # also accepts int64 index tensors that live on other devices.
    is_long_tensor = isinstance(index, torch.Tensor) and index.dtype == torch.int64
    if not (isinstance(index, (int, list)) or is_long_tensor):
        msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
        raise ValueError(msg % type(index))

    # Always index with a sequence so the batch dimension is preserved.
    if isinstance(index, int):
        index = [index]

    # Fail with a clear message rather than the cryptic error raised by
    # calling min()/max() on an empty sequence.
    if len(index) == 0:
        raise ValueError("Index is empty; at least one camera must be selected")

    # Validate both ends of the range up front so that every out-of-bounds
    # request is reported the same way, before any field is touched.
    num_cameras = len(self)
    lowest, highest = min(index), max(index)
    if highest >= num_cameras:
        raise ValueError(f"Index {highest} is out of bounds for select cameras")
    if lowest < -num_cameras:
        raise ValueError(f"Index {lowest} is out of bounds for select cameras")

    kwargs = {}
    for field in self._FIELDS:
        val = getattr(self, field, None)
        if val is None:
            continue

        # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
        # but provided as "in_ndc" on initialization
        arg_name = field[1:] if field.startswith("_") else field

        if isinstance(val, (str, bool)):
            # Non-batched settings are shared by every camera in the batch.
            kwargs[arg_name] = val
        elif isinstance(val, torch.Tensor):
            # In the init, all inputs will be converted to
            # tensors before setting as attributes
            kwargs[arg_name] = val[index]
        else:
            raise ValueError(f"Field {arg_name} type is not supported for indexing")

    kwargs["device"] = self.device
    return self.__class__(**kwargs)
############################################################
# Field of View Camera Classes #
@@ -434,6 +487,18 @@ class FoVPerspectiveCameras(CamerasBase):
for rasterization.
"""
# Attribute names read by CamerasBase.__getitem__: each one that is set on the
# instance is indexed (if it is a tensor) and passed back to the constructor
# as a keyword argument to build the selected cameras.
_FIELDS = (
"K",
"znear",
"zfar",
"aspect_ratio",
"fov",
"R",
"T",
"degrees",
)
def __init__(
self,
znear=1.0,
@@ -590,7 +655,7 @@ class FoVPerspectiveCameras(CamerasBase):
xy_depth: torch.Tensor,
world_coordinates: bool = True,
scaled_depth_input: bool = False,
**kwargs
**kwargs,
) -> torch.Tensor:
""">!
FoV cameras further allow for passing depth in world units
@@ -681,6 +746,20 @@ class FoVOrthographicCameras(CamerasBase):
The definition of the parameters follow the OpenGL orthographic camera.
"""
# Attribute names read by CamerasBase.__getitem__: each one that is set on the
# instance is indexed (if it is a tensor) and passed back to the constructor
# as a keyword argument to build the selected cameras.
_FIELDS = (
"K",
"znear",
"zfar",
"R",
"T",
"max_y",
"min_y",
"max_x",
"min_x",
"scale_xyz",
)
def __init__(
self,
znear=1.0,
@@ -819,7 +898,7 @@ class FoVOrthographicCameras(CamerasBase):
xy_depth: torch.Tensor,
world_coordinates: bool = True,
scaled_depth_input: bool = False,
**kwargs
**kwargs,
) -> torch.Tensor:
""">!
FoV cameras further allow for passing depth in world units
@@ -907,6 +986,17 @@ class PerspectiveCameras(CamerasBase):
If parameters are specified in screen space, `in_ndc` must be set to False.
"""
# Attribute names read by CamerasBase.__getitem__: each one that is set on the
# instance is indexed (if it is a tensor) and passed back to the constructor
# as a keyword argument to build the selected cameras.
_FIELDS = (
"K",
"R",
"T",
"focal_length",
"principal_point",
"_in_ndc",  # constructor arg is `in_ndc`; __getitem__ strips the leading underscore
"image_size",
)
def __init__(
self,
focal_length=1.0,
@@ -1007,7 +1097,7 @@ class PerspectiveCameras(CamerasBase):
xy_depth: torch.Tensor,
world_coordinates: bool = True,
from_ndc: bool = False,
**kwargs
**kwargs,
) -> torch.Tensor:
"""
Args:
@@ -1126,6 +1216,17 @@ class OrthographicCameras(CamerasBase):
If parameters are specified in screen space, `in_ndc` must be set to False.
"""
# Attribute names read by CamerasBase.__getitem__: each one that is set on the
# instance is indexed (if it is a tensor) and passed back to the constructor
# as a keyword argument to build the selected cameras.
_FIELDS = (
"K",
"R",
"T",
"focal_length",
"principal_point",
"_in_ndc",  # constructor arg is `in_ndc`; __getitem__ strips the leading underscore
"image_size",
)
def __init__(
self,
focal_length=1.0,
@@ -1225,7 +1326,7 @@ class OrthographicCameras(CamerasBase):
xy_depth: torch.Tensor,
world_coordinates: bool = True,
from_ndc: bool = False,
**kwargs
**kwargs,
) -> torch.Tensor:
"""
Args:

View File

@@ -155,7 +155,7 @@ class TensorProperties(nn.Module):
Returns:
if `index` is an index int/slice return a TensorAccessor class
with getattribute/setattribute methods which return/update the value
at the index in the original camera.
at the index in the original class.
"""
if isinstance(index, (int, slice)):
return TensorAccessor(class_object=self, index=index)