mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Return R2N2 renderings
Summary: R2N2 returns R2N2's own renderings of ShapeNetCore models. Reviewed By: nikhilaravi Differential Revision: D22266988 fbshipit-source-id: 36e67bd06c6459773e6e5f654259166b579be36a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5636eb6152
commit
dc08c30583
@@ -56,7 +56,8 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_load_R2N2(self):
|
||||
"""
|
||||
Test the loaded train split of R2N2 return items of the correct shapes and types.
|
||||
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
|
||||
check the first image returned is correct.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
@@ -68,8 +69,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
||||
|
||||
# Randomly retrieve an object from the dataset.
|
||||
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
|
||||
# Check that data type and shape of the item returned by __getitem__ are correct.
|
||||
rand_idx = torch.randint(len(r2n2_dataset), (1,))
|
||||
rand_obj = r2n2_dataset[rand_idx]
|
||||
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
|
||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||
self.assertTrue(verts.dtype == torch.float32)
|
||||
self.assertTrue(faces.dtype == torch.int64)
|
||||
@@ -78,6 +80,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(faces.ndim, 2)
|
||||
self.assertEqual(faces.shape[-1], 3)
|
||||
|
||||
# Check that image batch returned by __getitem__ has the correct shape.
|
||||
self.assertEqual(rand_obj["images"].shape[0], 24)
|
||||
self.assertEqual(rand_obj["images"].shape[1], 137)
|
||||
self.assertEqual(rand_obj["images"].shape[2], 137)
|
||||
self.assertEqual(rand_obj["images"].shape[-1], 3)
|
||||
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
|
||||
|
||||
def test_collate_models(self):
|
||||
"""
|
||||
Test collate_batched_meshes returns items of the correct shapes and types.
|
||||
@@ -118,6 +127,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(len(object_batch["label"]), batch_size)
|
||||
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
||||
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
||||
self.assertEqual(object_batch["images"].shape[0], batch_size)
|
||||
|
||||
def test_catch_render_arg_errors(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user