mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Test R2N2 loads correct numbers of instances
Summary: Sample/Get all views at the loading phase instead of returning phase; Load only views from the split instead of all 24 views; Test the numbers of views loaded are correct for each category. Reviewed By: nikhilaravi Differential Revision: D22631414 fbshipit-source-id: 1c5ce99fe2bdf6618c1aa0b69bb6899473376bc2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7cb9d8ea86
commit
483e538dae
@@ -59,20 +59,30 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
|
||||
check the first image returned is correct.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
# Load dataset in the test split.
|
||||
r2n2_dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
|
||||
# Check total number of objects in the dataset is correct.
|
||||
with open(SPLITS_PATH) as splits:
|
||||
split_dict = json.load(splits)["train"]
|
||||
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
|
||||
split_dict = json.load(splits)["test"]
|
||||
model_nums = [len(split_dict[synset]) for synset in split_dict]
|
||||
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
||||
|
||||
# Randomly retrieve an object from the dataset.
|
||||
rand_idx = torch.randint(len(r2n2_dataset), (1,))
|
||||
rand_obj = r2n2_dataset[rand_idx]
|
||||
# Check the numbers of loaded instances for each category are correct.
|
||||
for synset in split_dict:
|
||||
split_synset_nums = sum(
|
||||
len(split_dict[synset][model]) for model in split_dict[synset]
|
||||
)
|
||||
idx_start = r2n2_dataset.synset_start_idxs[synset]
|
||||
idx_end = idx_start + r2n2_dataset.synset_num_models[synset]
|
||||
synset_views_list = r2n2_dataset.views_per_model_list[idx_start:idx_end]
|
||||
loaded_synset_views = sum(len(views) for views in synset_views_list)
|
||||
self.assertEqual(loaded_synset_views, split_synset_nums)
|
||||
|
||||
# Retrieve an object from the dataset.
|
||||
r2n2_obj = r2n2_dataset[39]
|
||||
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
|
||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||
verts, faces = r2n2_obj["verts"], r2n2_obj["faces"]
|
||||
self.assertTrue(verts.dtype == torch.float32)
|
||||
self.assertTrue(faces.dtype == torch.int64)
|
||||
self.assertEqual(verts.ndim, 2)
|
||||
@@ -81,11 +91,17 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(faces.shape[-1], 3)
|
||||
|
||||
# Check that image batch returned by __getitem__ has the correct shape.
|
||||
self.assertEqual(rand_obj["images"].shape[0], 24)
|
||||
self.assertEqual(rand_obj["images"].shape[1], 137)
|
||||
self.assertEqual(rand_obj["images"].shape[2], 137)
|
||||
self.assertEqual(rand_obj["images"].shape[-1], 3)
|
||||
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
|
||||
self.assertEqual(r2n2_obj["images"].shape[0], 24)
|
||||
self.assertEqual(r2n2_obj["images"].shape[1], 137)
|
||||
self.assertEqual(r2n2_obj["images"].shape[2], 137)
|
||||
self.assertEqual(r2n2_obj["images"].shape[-1], 3)
|
||||
self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1)
|
||||
self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2)
|
||||
|
||||
# Check models with total view counts less than 24 return image batches
|
||||
# of the correct shapes.
|
||||
self.assertEqual(r2n2_dataset[635]["images"].shape[0], 5)
|
||||
self.assertEqual(r2n2_dataset[8369]["images"].shape[0], 10)
|
||||
|
||||
def test_collate_models(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user