mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Enable __getitem__ for Cameras to return an instance of Cameras
Summary: Added a custom `__getitem__` method to `CamerasBase` which returns an instance of the appropriate camera class instead of a `TensorAccessor`. Long term we should deprecate `TensorAccessor` and the `__getitem__` method on `TensorProperties`. FB: In the next diff I will update the uses of `select_cameras` in implicitron. Reviewed By: bottler. Differential Revision: D33185885. fbshipit-source-id: c31995d0eb126981e91ba61a6151d5404b263f67
This commit is contained in:
parent
cc3259ba93
commit
28ccdb7328
@ -75,6 +75,10 @@ class CamerasBase(TensorProperties):
|
||||
boolean argument of the function.
|
||||
"""
|
||||
|
||||
# Used in __getitem__ to index the relevant fields
|
||||
# Each camera class should override this at class level with the names of the attributes that __getitem__ must index and pass back to __init__
|
||||
_FIELDS: Tuple = ()
|
||||
|
||||
def get_projection_transform(self):
|
||||
"""
|
||||
Calculate the projective transformation matrix.
|
||||
@ -362,6 +366,55 @@ class CamerasBase(TensorProperties):
|
||||
"""
|
||||
return self.image_size if hasattr(self, "image_size") else None
|
||||
|
||||
def __getitem__(
    self, index: Union[int, List[int], torch.LongTensor]
) -> "CamerasBase":
    """
    Override for the __getitem__ method in TensorProperties which needs to be
    refactored.

    Args:
        index: an int/list/long tensor used to index all the fields in the
            cameras given by self._FIELDS. Negative values count from the end.
            A 0-dim long tensor is treated like an int.

    Returns:
        an instance of the current cameras class with only the values at the
        selected index. An int index is wrapped in a list, so the tensor
        fields of the result keep their leading batch dimension.

    Raises:
        ValueError: if `index` has an unsupported type or shape, is empty, is
            out of bounds, or if a field listed in self._FIELDS holds a value
            that is neither a str, a bool nor a tensor.
    """
    # Check the dtype rather than isinstance(index, torch.LongTensor): the
    # latter is only true for CPU tensors, so int64 indices living on another
    # device were rejected as an invalid type.
    is_long_tensor = isinstance(index, torch.Tensor) and index.dtype == torch.int64
    if not is_long_tensor and not isinstance(index, (int, list)):
        msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
        raise ValueError(msg % type(index))

    if is_long_tensor:
        if index.dim() > 1:
            raise ValueError(
                f"Invalid index shape {tuple(index.shape)}, "
                "expected a 0-dim or 1-dim tensor"
            )
        # A 0-dim tensor cannot be iterated or reduced like a sequence, so
        # promote it to a single-element 1-dim index.
        index = index.reshape(-1)
        num_selected = index.numel()
    else:
        if isinstance(index, int):
            index = [index]
        num_selected = len(index)

    # min()/max() fail with an unrelated message on an empty sequence.
    if num_selected == 0:
        raise ValueError("Index is empty, expected at least one camera to select")

    if is_long_tensor:
        lowest, highest = int(index.min()), int(index.max())
    else:
        lowest, highest = min(index), max(index)

    # Negative indices are valid down to -len(self); anything beyond either
    # end is reported uniformly instead of leaking an IndexError from torch.
    num_cameras = len(self)
    if highest >= num_cameras:
        raise ValueError(f"Index {highest} is out of bounds for select cameras")
    if lowest < -num_cameras:
        raise ValueError(f"Index {lowest} is out of bounds for select cameras")

    kwargs = {}
    for field in self._FIELDS:
        val = getattr(self, field, None)
        if val is None:
            continue

        # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
        # but provided as "in_ndc" on initialization
        name = field[1:] if field.startswith("_") else field

        if isinstance(val, (str, bool)):
            # Non-batched settings are copied over unchanged.
            kwargs[name] = val
        elif isinstance(val, torch.Tensor):
            # In the init, all inputs will be converted to
            # tensors before setting as attributes
            kwargs[name] = val[index]
        else:
            raise ValueError(f"Field {name} type is not supported for indexing")

    kwargs["device"] = self.device
    return self.__class__(**kwargs)
|
||||
|
||||
|
||||
############################################################
|
||||
# Field of View Camera Classes #
|
||||
@ -434,6 +487,18 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
for rasterization.
|
||||
"""
|
||||
|
||||
# For __getitem__
|
||||
_FIELDS = (
|
||||
"K",
|
||||
"znear",
|
||||
"zfar",
|
||||
"aspect_ratio",
|
||||
"fov",
|
||||
"R",
|
||||
"T",
|
||||
"degrees",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
znear=1.0,
|
||||
@ -590,7 +655,7 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
xy_depth: torch.Tensor,
|
||||
world_coordinates: bool = True,
|
||||
scaled_depth_input: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
""">!
|
||||
FoV cameras further allow for passing depth in world units
|
||||
@ -681,6 +746,20 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
The definition of the parameters follow the OpenGL orthographic camera.
|
||||
"""
|
||||
|
||||
# For __getitem__
|
||||
_FIELDS = (
|
||||
"K",
|
||||
"znear",
|
||||
"zfar",
|
||||
"R",
|
||||
"T",
|
||||
"max_y",
|
||||
"min_y",
|
||||
"max_x",
|
||||
"min_x",
|
||||
"scale_xyz",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
znear=1.0,
|
||||
@ -819,7 +898,7 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
xy_depth: torch.Tensor,
|
||||
world_coordinates: bool = True,
|
||||
scaled_depth_input: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
""">!
|
||||
FoV cameras further allow for passing depth in world units
|
||||
@ -907,6 +986,17 @@ class PerspectiveCameras(CamerasBase):
|
||||
If parameters are specified in screen space, `in_ndc` must be set to False.
|
||||
"""
|
||||
|
||||
# For __getitem__
|
||||
_FIELDS = (
|
||||
"K",
|
||||
"R",
|
||||
"T",
|
||||
"focal_length",
|
||||
"principal_point",
|
||||
"_in_ndc", # arg is in_ndc but attribute set as _in_ndc
|
||||
"image_size",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
focal_length=1.0,
|
||||
@ -1007,7 +1097,7 @@ class PerspectiveCameras(CamerasBase):
|
||||
xy_depth: torch.Tensor,
|
||||
world_coordinates: bool = True,
|
||||
from_ndc: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -1126,6 +1216,17 @@ class OrthographicCameras(CamerasBase):
|
||||
If parameters are specified in screen space, `in_ndc` must be set to False.
|
||||
"""
|
||||
|
||||
# For __getitem__
|
||||
_FIELDS = (
|
||||
"K",
|
||||
"R",
|
||||
"T",
|
||||
"focal_length",
|
||||
"principal_point",
|
||||
"_in_ndc",
|
||||
"image_size",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
focal_length=1.0,
|
||||
@ -1225,7 +1326,7 @@ class OrthographicCameras(CamerasBase):
|
||||
xy_depth: torch.Tensor,
|
||||
world_coordinates: bool = True,
|
||||
from_ndc: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
|
@ -155,7 +155,7 @@ class TensorProperties(nn.Module):
|
||||
Returns:
|
||||
if `index` is an index int/slice return a TensorAccessor class
|
||||
with getattribute/setattribute methods which return/update the value
|
||||
at the index in the original camera.
|
||||
at the index in the original class.
|
||||
"""
|
||||
if isinstance(index, (int, slice)):
|
||||
return TensorAccessor(class_object=self, index=index)
|
||||
|
@ -783,18 +783,53 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(cam.znear.shape == (2,))
|
||||
self.assertTrue(cam.zfar.shape == (2,))
|
||||
|
||||
# update znear element 1
|
||||
cam[1].znear = 20.0
|
||||
self.assertTrue(cam.znear[1] == 20.0)
|
||||
|
||||
# Get item and get value
|
||||
c0 = cam[0]
|
||||
self.assertTrue(c0.zfar == 100.0)
|
||||
|
||||
# Test to
|
||||
new_cam = cam.to(device=device)
|
||||
self.assertTrue(new_cam.device == device)
|
||||
|
||||
def test_getitem(self):
    """
    Index a batch of six FoVPerspectiveCameras with an int, a list of ints
    and a long tensor, then check the invalid-index error paths.
    """
    rotations = torch.randn((6, 3, 3))
    cameras = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=rotations)

    # An int index gives back the same camera class with the same set of
    # instance attributes.
    first = cameras[0]
    self.assertTrue(isinstance(first, FoVPerspectiveCameras))
    self.assertEqual(cameras.__dict__.keys(), first.__dict__.keys())

    # The selected camera keeps a batch dimension of size 1.
    self.assertEqual(len(first), 1)
    self.assertClose(first.zfar, torch.tensor([100.0]))
    self.assertClose(first.znear, torch.tensor([10.0]))
    self.assertClose(first.R, rotations[0:1, ...])
    self.assertEqual(first.device, torch.device("cpu"))

    # A List[int] index and a torch.LongTensor index each select three cameras.
    selections = (
        ([0, 1, 2], rotations[0:3, ...]),
        (torch.tensor([1, 3, 5], dtype=torch.int64), rotations[[1, 3, 5], ...]),
    )
    for picked, expected_rotations in selections:
        subset = cameras[picked]
        self.assertEqual(len(subset), 3)
        self.assertClose(subset.zfar, torch.tensor([100.0] * 3))
        self.assertClose(subset.znear, torch.tensor([10.0] * 3))
        self.assertClose(subset.R, expected_rotations)

    # An index past the end must raise ValueError.
    with self.assertRaisesRegex(ValueError, "out of bounds"):
        cameras[6]

    # Slices and non-long tensors are rejected as invalid index types.
    for rejected in (slice(0, 1), torch.tensor([1, 3, 5], dtype=torch.float32)):
        with self.assertRaisesRegex(ValueError, "Invalid index type"):
            cameras[rejected]
|
||||
|
||||
def test_get_full_transform(self):
|
||||
cam = FoVPerspectiveCameras()
|
||||
T = torch.tensor([0.0, 0.0, 1.0]).view(1, -1)
|
||||
@ -919,6 +954,30 @@ class TestFoVOrthographicProjection(TestCaseMixin, unittest.TestCase):
|
||||
self.assertFalse(cam.is_perspective())
|
||||
self.assertEqual(cam.get_znear(), 1.0)
|
||||
|
||||
def test_getitem(self):
    """
    Index a batch of six FoVOrthographicCameras with an int and with a long
    tensor and compare the selected fields against the expected values.
    """
    rotations = torch.randn((6, 3, 3))
    scale_xyz = torch.tensor([[1.0, 1.0, 1.0]], requires_grad=True)
    cameras = FoVOrthographicCameras(
        znear=10.0, zfar=100.0, R=rotations, scale_xyz=scale_xyz
    )

    # An int index gives back the same camera class with the same set of
    # instance attributes.
    first = cameras[0]
    self.assertTrue(isinstance(first, FoVOrthographicCameras))
    self.assertEqual(cameras.__dict__.keys(), first.__dict__.keys())

    # torch.LongTensor index selecting cameras 1, 3 and 5.
    picked = [1, 3, 5]
    subset = cameras[torch.tensor(picked, dtype=torch.int64)]
    self.assertEqual(len(subset), 3)

    # min_x / max_x are not passed to the constructor above; the expected
    # values are -1.0 and 1.0 for every selected camera.
    expected_fields = (
        ("zfar", torch.tensor([100.0] * 3)),
        ("znear", torch.tensor([10.0] * 3)),
        ("min_x", torch.tensor([-1.0] * 3)),
        ("max_x", torch.tensor([1.0] * 3)),
        ("R", rotations[picked, ...]),
        ("scale_xyz", scale_xyz.expand(3, -1)),
    )
    for field_name, expected in expected_fields:
        self.assertClose(getattr(subset, field_name), expected)
|
||||
|
||||
|
||||
############################################################
|
||||
# Orthographic Camera #
|
||||
@ -976,6 +1035,30 @@ class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
|
||||
self.assertFalse(cam.is_perspective())
|
||||
self.assertIsNone(cam.get_znear())
|
||||
|
||||
def test_getitem(self):
    """
    Index a batch of six OrthographicCameras with an int and with a long
    tensor and compare the selected fields against the expected values.
    """
    rotations = torch.randn((6, 3, 3))
    # NOTE(review): the principal point is created with shape (6, 2, 1);
    # confirm the trailing singleton dimension is intended.
    centers = torch.randn((6, 2, 1))
    cameras = OrthographicCameras(
        R=rotations,
        focal_length=5.0,
        principal_point=centers,
    )

    # An int index gives back the same camera class with the same set of
    # instance attributes.
    first = cameras[0]
    self.assertTrue(isinstance(first, OrthographicCameras))
    self.assertEqual(cameras.__dict__.keys(), first.__dict__.keys())

    # torch.LongTensor index selecting cameras 1, 3 and 5.
    picked = [1, 3, 5]
    subset = cameras[torch.tensor(picked, dtype=torch.int64)]
    self.assertEqual(len(subset), 3)
    expected_fields = (
        ("focal_length", torch.tensor([5.0] * 3)),
        ("R", rotations[picked, ...]),
        ("principal_point", centers[picked, ...]),
    )
    for field_name, expected in expected_fields:
        self.assertClose(getattr(subset, field_name), expected)
|
||||
|
||||
|
||||
############################################################
|
||||
# Perspective Camera #
|
||||
@ -1027,3 +1110,30 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
|
||||
cam = PerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
|
||||
self.assertTrue(cam.is_perspective())
|
||||
self.assertIsNone(cam.get_znear())
|
||||
|
||||
def test_getitem(self):
    """
    Index a batch of six PerspectiveCameras with an int and with a long
    tensor, compare the selected fields against the expected values and
    check that the `_in_ndc` flag survives indexing.
    """
    rotations = torch.randn((6, 3, 3))
    # NOTE(review): the principal point is created with shape (6, 2, 1);
    # confirm the trailing singleton dimension is intended.
    centers = torch.randn((6, 2, 1))
    cameras = PerspectiveCameras(
        R=rotations,
        focal_length=5.0,
        principal_point=centers,
    )

    # An int index gives back the same camera class with the same set of
    # instance attributes.
    first = cameras[0]
    self.assertTrue(isinstance(first, PerspectiveCameras))
    self.assertEqual(cameras.__dict__.keys(), first.__dict__.keys())

    # torch.LongTensor index selecting cameras 1, 3 and 5.
    picked = [1, 3, 5]
    subset = cameras[torch.tensor(picked, dtype=torch.int64)]
    self.assertEqual(len(subset), 3)
    expected_fields = (
        ("focal_length", torch.tensor([5.0] * 3)),
        ("R", rotations[picked, ...]),
        ("principal_point", centers[picked, ...]),
    )
    for field_name, expected in expected_fields:
        self.assertClose(getattr(subset, field_name), expected)

    # The private `_in_ndc` attribute must match between source and result.
    self.assertEqual(cameras._in_ndc, first._in_ndc)
|
||||
|
Loading…
x
Reference in New Issue
Block a user