Test R2N2 loads correct numbers of instances

Summary:
Sample/Get all views at the loading phase instead of returning phase;
Load only views from the split instead of all 24 views;
Test the numbers of views loaded are correct for each category.

Reviewed By: nikhilaravi

Differential Revision: D22631414

fbshipit-source-id: 1c5ce99fe2bdf6618c1aa0b69bb6899473376bc2
This commit is contained in:
Luya Gao
2020-07-23 10:15:50 -07:00
committed by Facebook GitHub Bot
parent 7cb9d8ea86
commit 483e538dae
5 changed files with 96 additions and 45 deletions

View File

@@ -59,20 +59,30 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
check the first image returned is correct.
"""
# Load dataset in the train split.
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Load dataset in the test split.
r2n2_dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Check total number of objects in the dataset is correct.
with open(SPLITS_PATH) as splits:
split_dict = json.load(splits)["train"]
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
split_dict = json.load(splits)["test"]
model_nums = [len(split_dict[synset]) for synset in split_dict]
self.assertEqual(len(r2n2_dataset), sum(model_nums))
# Randomly retrieve an object from the dataset.
rand_idx = torch.randint(len(r2n2_dataset), (1,))
rand_obj = r2n2_dataset[rand_idx]
# Check the numbers of loaded instances for each category are correct.
for synset in split_dict:
split_synset_nums = sum(
len(split_dict[synset][model]) for model in split_dict[synset]
)
idx_start = r2n2_dataset.synset_start_idxs[synset]
idx_end = idx_start + r2n2_dataset.synset_num_models[synset]
synset_views_list = r2n2_dataset.views_per_model_list[idx_start:idx_end]
loaded_synset_views = sum(len(views) for views in synset_views_list)
self.assertEqual(loaded_synset_views, split_synset_nums)
# Retrieve an object from the dataset.
r2n2_obj = r2n2_dataset[39]
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
verts, faces = rand_obj["verts"], rand_obj["faces"]
verts, faces = r2n2_obj["verts"], r2n2_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
self.assertTrue(faces.dtype == torch.int64)
self.assertEqual(verts.ndim, 2)
@@ -81,11 +91,17 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(faces.shape[-1], 3)
# Check that image batch returned by __getitem__ has the correct shape.
self.assertEqual(rand_obj["images"].shape[0], 24)
self.assertEqual(rand_obj["images"].shape[1], 137)
self.assertEqual(rand_obj["images"].shape[2], 137)
self.assertEqual(rand_obj["images"].shape[-1], 3)
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
self.assertEqual(r2n2_obj["images"].shape[0], 24)
self.assertEqual(r2n2_obj["images"].shape[1], 137)
self.assertEqual(r2n2_obj["images"].shape[2], 137)
self.assertEqual(r2n2_obj["images"].shape[-1], 3)
self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1)
self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2)
# Check models with total view counts less than 24 return image batches
# of the correct shapes.
self.assertEqual(r2n2_dataset[635]["images"].shape[0], 5)
self.assertEqual(r2n2_dataset[8369]["images"].shape[0], 10)
def test_collate_models(self):
"""