pytorch3d/tests/test_r2n2.py
Luya Gao 49b4ce1acc R2N2 skeleton
Summary: Skeleton of R2N2 that for now only returns verts and faces extracted from ShapeNetCore v1.

Reviewed By: nikhilaravi

Differential Revision: D22203656

fbshipit-source-id: 00db6ac76bfdb76fdbc77a2087c34a3f0ff01e6a
2020-07-14 14:53:58 -07:00

112 lines
4.5 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Sanity checks for loading R2N2.
"""
import json
import os
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.datasets import R2N2, collate_batched_meshes
from torch.utils.data import DataLoader
# Set these paths in order to run the tests. Each path may also be supplied via
# an environment variable of the same name (e.g. R2N2_PATH=/data/r2n2), so the
# tests can be enabled without editing this file. When neither is set the value
# is None and the tests are skipped.
R2N2_PATH = os.environ.get("R2N2_PATH")
SHAPENET_PATH = os.environ.get("SHAPENET_PATH")
SPLITS_PATH = os.environ.get("SPLITS_PATH")
class TestR2N2(TestCaseMixin, unittest.TestCase):
    """
    Sanity checks for the R2N2 dataset class and collate_batched_meshes.

    All tests are skipped unless SHAPENET_PATH, R2N2_PATH and SPLITS_PATH point
    to existing paths.
    """

    def setUp(self):
        """
        Check if the data paths are given, otherwise skip the tests.

        When the data is available, also seed torch's RNG so that the randomly
        selected objects (and therefore any failure) are reproducible.
        """
        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
            url = "https://www.shapenet.org/"
            msg = (
                "ShapeNet data not found, download from %s, update "
                "SHAPENET_PATH at the top of the file, and rerun."
            )
            self.skipTest(msg % url)
        if R2N2_PATH is None or not os.path.exists(R2N2_PATH):
            url = "http://3d-r2n2.stanford.edu/"
            msg = (
                "R2N2 data not found, download from %s, update "
                "R2N2_PATH at the top of the file, and rerun."
            )
            self.skipTest(msg % url)
        if SPLITS_PATH is None or not os.path.exists(SPLITS_PATH):
            # Built from adjacent literals (like the messages above) so the
            # message does not embed a newline and indentation.
            msg = (
                "Splits file not found, update SPLITS_PATH at the top "
                "of the file, and rerun."
            )
            self.skipTest(msg)
        # Make the random object selection in the tests deterministic.
        torch.manual_seed(1)

    def test_load_R2N2(self):
        """
        Test loading the train split of R2N2. Check that the loaded dataset
        returns items of the correct shapes and types.
        """
        # Load dataset in the train split.
        split = "train"
        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

        # The dataset should contain one object per model listed under each
        # synset of this split in the splits file.
        with open(SPLITS_PATH, encoding="utf-8") as splits:
            split_dict = json.load(splits)[split]
        expected_len = sum(len(models) for models in split_dict.values())
        self.assertEqual(len(r2n2_dataset), expected_len)
        # Fail clearly on an empty split instead of erroring inside randint.
        self.assertGreater(len(r2n2_dataset), 0)

        # Randomly retrieve an object from the dataset, using a plain int index.
        rand_idx = int(torch.randint(len(r2n2_dataset), (1,)))
        rand_obj = r2n2_dataset[rand_idx]

        # Check that data type and shape of the item returned by __getitem__
        # are correct: verts and faces are both (N, 3) tensors.
        verts, faces = rand_obj["verts"], rand_obj["faces"]
        self.assertEqual(verts.dtype, torch.float32)
        self.assertEqual(faces.dtype, torch.int64)
        self.assertEqual(verts.ndim, 2)
        self.assertEqual(verts.shape[-1], 3)
        self.assertEqual(faces.ndim, 2)
        self.assertEqual(faces.shape[-1], 3)

    def test_collate_models(self):
        """
        Test that collate_batched_meshes returns items of the correct shapes and
        types. Check that when collate_batched_meshes is passed to DataLoader,
        batches of the correct shapes and types are returned.
        """
        # Load dataset in the train split.
        split = "train"
        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

        # Randomly retrieve several objects from the dataset and collate them.
        num_objs = 6
        rand_idxs = torch.randint(len(r2n2_dataset), (num_objs,)).tolist()
        collated_meshes = collate_batched_meshes(
            [r2n2_dataset[idx] for idx in rand_idxs]
        )

        # Check the collated verts and faces have the correct shapes.
        verts, faces = collated_meshes["verts"], collated_meshes["faces"]
        self.assertEqual(len(verts), num_objs)
        self.assertEqual(len(faces), num_objs)
        self.assertEqual(verts[0].shape[-1], 3)
        self.assertEqual(faces[0].shape[-1], 3)

        # Check the collated mesh has the correct shape.
        mesh = collated_meshes["mesh"]
        self.assertEqual(mesh.verts_padded().shape[0], num_objs)
        self.assertEqual(mesh.verts_padded().shape[-1], 3)
        self.assertEqual(mesh.faces_padded().shape[0], num_objs)
        self.assertEqual(mesh.faces_padded().shape[-1], 3)

        # Pass the custom collate_fn function to DataLoader and check elements
        # in batch have the correct shape.
        batch_size = 12
        r2n2_loader = DataLoader(
            r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
        )
        object_batch = next(iter(r2n2_loader))
        # DataLoader keeps incomplete batches by default (drop_last=False), so
        # a dataset smaller than batch_size yields a shorter first batch.
        expected_size = min(batch_size, len(r2n2_dataset))
        self.assertEqual(len(object_batch["synset_id"]), expected_size)
        self.assertEqual(len(object_batch["model_id"]), expected_size)
        self.assertEqual(len(object_batch["label"]), expected_size)
        self.assertEqual(object_batch["mesh"].verts_padded().shape[0], expected_size)
        self.assertEqual(object_batch["mesh"].faces_padded().shape[0], expected_size)