Boolean indexing of cameras

Summary: Reasonable to expect bool indexing.

Reviewed By: bottler, kjchalup

Differential Revision: D38741446

fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
This commit is contained in:
Roman Shapovalov
2022-08-16 15:19:39 -07:00
committed by Facebook GitHub Bot
parent 60808972b8
commit b7c826b786
6 changed files with 58 additions and 18 deletions

View File

@@ -884,7 +884,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertTrue(new_cam.device == device)
def test_getitem(self):
R_matrix = torch.randn((6, 3, 3))
N_CAMERAS = 6
R_matrix = torch.randn((N_CAMERAS, 3, 3))
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
# Check get item returns an instance of the same class
@@ -908,22 +909,39 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(c012.R, R_matrix[0:3, ...])
# Check torch.LongTensor index
index = torch.tensor([1, 3, 5], dtype=torch.int64)
SLICE = [1, 3, 5]
index = torch.tensor(SLICE, dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check torch.BoolTensor index
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
index = torch.tensor(bool_slice, dtype=torch.bool)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
self.assertClose(c135.R, R_matrix[SLICE, ...])
# Check errors with get item
with self.assertRaisesRegex(ValueError, "out of bounds"):
cam[6]
cam[N_CAMERAS]
with self.assertRaisesRegex(ValueError, "does not match cameras"):
index = torch.tensor([1, 0, 1], dtype=torch.bool)
cam[index]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
cam[slice(0, 1)]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor([1, 3, 5], dtype=torch.float32)
cam[[True, False]]
with self.assertRaisesRegex(ValueError, "Invalid index type"):
index = torch.tensor(SLICE, dtype=torch.float32)
cam[index]
def test_get_full_transform(self):