mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Boolean indexing of cameras
Summary: Reasonable to expect bool indexing. Reviewed By: bottler, kjchalup Differential Revision: D38741446 fbshipit-source-id: 22b607bf13110043c5624196c66ca1484fdbce6c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
60808972b8
commit
b7c826b786
@@ -884,7 +884,8 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(new_cam.device == device)
|
||||
|
||||
def test_getitem(self):
|
||||
R_matrix = torch.randn((6, 3, 3))
|
||||
N_CAMERAS = 6
|
||||
R_matrix = torch.randn((N_CAMERAS, 3, 3))
|
||||
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
|
||||
|
||||
# Check get item returns an instance of the same class
|
||||
@@ -908,22 +909,39 @@ class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(c012.R, R_matrix[0:3, ...])
|
||||
|
||||
# Check torch.LongTensor index
|
||||
index = torch.tensor([1, 3, 5], dtype=torch.int64)
|
||||
SLICE = [1, 3, 5]
|
||||
index = torch.tensor(SLICE, dtype=torch.int64)
|
||||
c135 = cam[index]
|
||||
self.assertEqual(len(c135), 3)
|
||||
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
|
||||
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
|
||||
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
|
||||
self.assertClose(c135.R, R_matrix[SLICE, ...])
|
||||
|
||||
# Check torch.BoolTensor index
|
||||
bool_slice = [i in SLICE for i in range(N_CAMERAS)]
|
||||
index = torch.tensor(bool_slice, dtype=torch.bool)
|
||||
c135 = cam[index]
|
||||
self.assertEqual(len(c135), 3)
|
||||
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
|
||||
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
|
||||
self.assertClose(c135.R, R_matrix[SLICE, ...])
|
||||
|
||||
# Check errors with get item
|
||||
with self.assertRaisesRegex(ValueError, "out of bounds"):
|
||||
cam[6]
|
||||
cam[N_CAMERAS]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "does not match cameras"):
|
||||
index = torch.tensor([1, 0, 1], dtype=torch.bool)
|
||||
cam[index]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Invalid index type"):
|
||||
cam[slice(0, 1)]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Invalid index type"):
|
||||
index = torch.tensor([1, 3, 5], dtype=torch.float32)
|
||||
cam[[True, False]]
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Invalid index type"):
|
||||
index = torch.tensor(SLICE, dtype=torch.float32)
|
||||
cam[index]
|
||||
|
||||
def test_get_full_transform(self):
|
||||
|
||||
Reference in New Issue
Block a user