Return R2N2 renderings

Summary: R2N2 returns R2N2's own renderings of ShapeNetCore models.

Reviewed By: nikhilaravi

Differential Revision: D22266988

fbshipit-source-id: 36e67bd06c6459773e6e5f654259166b579be36a
This commit is contained in:
Luya Gao
2020-07-14 14:52:21 -07:00
committed by Facebook GitHub Bot
parent 5636eb6152
commit dc08c30583
3 changed files with 91 additions and 8 deletions

View File

@@ -56,7 +56,8 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
def test_load_R2N2(self):
"""
Test the loaded train split of R2N2 return items of the correct shapes and types.
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
check the first image returned is correct.
"""
# Load dataset in the train split.
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
@@ -68,8 +69,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(len(r2n2_dataset), sum(model_nums))
# Randomly retrieve an object from the dataset.
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
# Check that data type and shape of the item returned by __getitem__ are correct.
rand_idx = torch.randint(len(r2n2_dataset), (1,))
rand_obj = r2n2_dataset[rand_idx]
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
verts, faces = rand_obj["verts"], rand_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
self.assertTrue(faces.dtype == torch.int64)
@@ -78,6 +80,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(faces.ndim, 2)
self.assertEqual(faces.shape[-1], 3)
# Check that image batch returned by __getitem__ has the correct shape.
self.assertEqual(rand_obj["images"].shape[0], 24)
self.assertEqual(rand_obj["images"].shape[1], 137)
self.assertEqual(rand_obj["images"].shape[2], 137)
self.assertEqual(rand_obj["images"].shape[-1], 3)
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
def test_collate_models(self):
"""
Test collate_batched_meshes returns items of the correct shapes and types.
@@ -118,6 +127,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(len(object_batch["label"]), batch_size)
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
self.assertEqual(object_batch["images"].shape[0], batch_size)
def test_catch_render_arg_errors(self):
"""