Enable __getitem__ for Cameras to return an instance of Cameras

Summary:
Added a custom `__getitem__` method to `CamerasBase` which returns an instance of the appropriate camera instead of the `TensorAccessor` class.

Long term we should deprecate the `TensorAccessor` and the `__getitem__` method on `TensorProperties`

FB: In the next diff I will update the uses of `select_cameras` in implicitron.

Reviewed By: bottler

Differential Revision: D33185885

fbshipit-source-id: c31995d0eb126981e91ba61a6151d5404b263f67
This commit is contained in:
Nikhila Ravi
2021-12-21 05:45:32 -08:00
committed by Facebook GitHub Bot
parent cc3259ba93
commit 28ccdb7328
3 changed files with 224 additions and 13 deletions

View File

@@ -783,18 +783,53 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertTrue(cam.znear.shape == (2,))
self.assertTrue(cam.zfar.shape == (2,))
# update znear element 1
cam[1].znear = 20.0
self.assertTrue(cam.znear[1] == 20.0)
# Get item and get value
c0 = cam[0]
self.assertTrue(c0.zfar == 100.0)
# Test to
new_cam = cam.to(device=device)
self.assertTrue(new_cam.device == device)
def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
# Check get item returns an instance of the same class
# with all the same keys
c0 = cam[0]
self.assertTrue(isinstance(c0, FoVPerspectiveCameras))
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
# Check all fields correct in get item with int index
self.assertEqual(len(c0), 1)
self.assertClose(c0.zfar, torch.tensor([100.0]))
self.assertClose(c0.znear, torch.tensor([10.0]))
self.assertClose(c0.R, R_matrix[0:1, ...])
self.assertEqual(c0.device, torch.device("cpu"))
# Check list(int) index
c012 = cam[[0, 1, 2]]
self.assertEqual(len(c012), 3)
self.assertClose(c012.zfar, torch.tensor([100.0] * 3))
self.assertClose(c012.znear, torch.tensor([10.0] * 3))
self.assertClose(c012.R, R_matrix[0:3, ...])
# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
# Check errors with get item
with self.assertRaisesRegex(ValueError, "out of bounds"):
cam[6]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[slice(0, 1)]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor([1, 3, 5], dtype=torch.float32)
cam[index]
def test_get_full_transform(self):
cam = FoVPerspectiveCameras()
T = torch.tensor([0.0, 0.0, 1.0]).view(1, -1)
@@ -919,6 +954,30 @@ class TestFoVOrthographicProjection(TestCaseMixin, unittest.TestCase):
self.assertFalse(cam.is_perspective())
self.assertEqual(cam.get_znear(), 1.0)
def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
scale = torch.tensor([[1.0, 1.0, 1.0]], requires_grad=True)
cam = FoVOrthographicCameras(
znear=10.0, zfar=100.0, R=R_matrix, scale_xyz=scale
)
# Check get item returns an instance of the same class
# with all the same keys
c0 = cam[0]
self.assertTrue(isinstance(c0, FoVOrthographicCameras))
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.min_x, torch.tensor([-1.0] * 3))
self.assertClose(c135.max_x, torch.tensor([1.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.scale_xyz, scale.expand(3, -1))
############################################################
# Orthographic Camera #
@@ -976,6 +1035,30 @@ class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
self.assertFalse(cam.is_perspective())
self.assertIsNone(cam.get_znear())
def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
principal_point = torch.randn((6, 2, 1))
focal_length = 5.0
cam = OrthographicCameras(
R=R_matrix,
focal_length=focal_length,
principal_point=principal_point,
)
# Check get item returns an instance of the same class
# with all the same keys
c0 = cam[0]
self.assertTrue(isinstance(c0, OrthographicCameras))
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
############################################################
# Perspective Camera #
@@ -1027,3 +1110,30 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
cam = PerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
self.assertTrue(cam.is_perspective())
self.assertIsNone(cam.get_znear())
def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
principal_point = torch.randn((6, 2, 1))
focal_length = 5.0
cam = PerspectiveCameras(
R=R_matrix,
focal_length=focal_length,
principal_point=principal_point,
)
# Check get item returns an instance of the same class
# with all the same keys
c0 = cam[0]
self.assertTrue(isinstance(c0, PerspectiveCameras))
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
# Check in_ndc is handled correctly
self.assertEqual(cam._in_ndc, c0._in_ndc)