mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: Skeleton of R2N2 that for now only returns verts and faces extracted from ShapeNetCore v1. Reviewed By: nikhilaravi Differential Revision: D22203656 fbshipit-source-id: 00db6ac76bfdb76fdbc77a2087c34a3f0ff01e6a
112 lines
4.5 KiB
Python
112 lines
4.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
"""
|
|
Sanity checks for loading R2N2.
|
|
"""
|
|
import json
|
|
import os
|
|
import unittest
|
|
|
|
import torch
|
|
from common_testing import TestCaseMixin
|
|
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
# Paths to the data needed to run these tests. Either set them here or export
# environment variables with the same names (an empty value counts as unset).
# Any test whose required path is None or does not exist is skipped in setUp.
R2N2_PATH = os.environ.get("R2N2_PATH") or None
SHAPENET_PATH = os.environ.get("SHAPENET_PATH") or None
SPLITS_PATH = os.environ.get("SPLITS_PATH") or None
|
|
|
|
|
|
class TestR2N2(TestCaseMixin, unittest.TestCase):
    """Sanity checks for the R2N2 dataset loader and collate_batched_meshes."""

    def setUp(self):
        """
        Check if the data paths are given otherwise skip tests.

        When all paths are available, seed torch's RNG so the objects sampled at
        random by the tests are the same on every run (reproducible failures).
        """
        super().setUp()
        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
            url = "https://www.shapenet.org/"
            msg = (
                "ShapeNet data not found, download from %s, update "
                "SHAPENET_PATH at the top of the file, and rerun."
            )
            self.skipTest(msg % url)
        if R2N2_PATH is None or not os.path.exists(R2N2_PATH):
            url = "http://3d-r2n2.stanford.edu/"
            msg = (
                "R2N2 data not found, download from %s, update "
                "R2N2_PATH at the top of the file, and rerun."
            )
            self.skipTest(msg % url)
        if SPLITS_PATH is None or not os.path.exists(SPLITS_PATH):
            msg = (
                "Splits file not found, update SPLITS_PATH at the top "
                "of the file, and rerun."
            )
            self.skipTest(msg)
        # The tests below draw indices with torch.randint; fix the seed so a
        # failing sample can be reproduced.
        torch.manual_seed(1)

    def test_load_R2N2(self):
        """
        Test loading the train split of R2N2. Check that the loaded dataset has
        one item per model in the splits file and that a returned item has
        verts/faces of the correct shapes and types.
        """
        # Load dataset in the train split.
        split = "train"
        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

        # The dataset length must equal the number of models listed for this
        # split, summed over all synsets in the splits file.
        with open(SPLITS_PATH, encoding="utf-8") as splits:
            split_dict = json.load(splits)[split]
        expected_num_models = sum(len(models) for models in split_dict.values())
        self.assertEqual(len(r2n2_dataset), expected_num_models)

        # Randomly retrieve an object from the dataset; .item() turns the
        # sampled 1-element tensor into a plain Python int index.
        rand_idx = torch.randint(len(r2n2_dataset), (1,)).item()
        rand_obj = r2n2_dataset[rand_idx]

        # Check that data type and shape of the item returned by __getitem__
        # are correct.
        verts, faces = rand_obj["verts"], rand_obj["faces"]
        self.assertEqual(verts.dtype, torch.float32)
        self.assertEqual(faces.dtype, torch.int64)
        self.assertEqual(verts.ndim, 2)
        self.assertEqual(verts.shape[-1], 3)
        self.assertEqual(faces.ndim, 2)
        self.assertEqual(faces.shape[-1], 3)

    def test_collate_models(self):
        """
        Test collate_batched_meshes returns items of the correct shapes and types.
        Check that when collate_batched_meshes is passed to Dataloader, batches of
        the correct shapes and types are returned.
        """
        # Load dataset in the train split.
        split = "train"
        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

        # Randomly retrieve several objects from the dataset and collate them.
        num_objs = 6
        rand_idxs = torch.randint(len(r2n2_dataset), (num_objs,)).tolist()
        collated_meshes = collate_batched_meshes(
            [r2n2_dataset[idx] for idx in rand_idxs]
        )

        # Check the collated verts and faces have the correct shapes, for every
        # object rather than only the first one.
        verts, faces = collated_meshes["verts"], collated_meshes["faces"]
        self.assertEqual(len(verts), num_objs)
        self.assertEqual(len(faces), num_objs)
        for obj_verts, obj_faces in zip(verts, faces):
            self.assertEqual(obj_verts.shape[-1], 3)
            self.assertEqual(obj_faces.shape[-1], 3)

        # Check the collated mesh has the correct shape.
        mesh = collated_meshes["mesh"]
        verts_padded, faces_padded = mesh.verts_padded(), mesh.faces_padded()
        self.assertEqual(verts_padded.shape[0], num_objs)
        self.assertEqual(verts_padded.shape[-1], 3)
        self.assertEqual(faces_padded.shape[0], num_objs)
        self.assertEqual(faces_padded.shape[-1], 3)

        # Pass the custom collate_fn function to DataLoader and check elements
        # in batch have the correct shape.
        batch_size = 12
        r2n2_loader = DataLoader(
            r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
        )
        object_batch = next(iter(r2n2_loader))
        # The loader does not shuffle, so the first batch holds the first
        # batch_size items -- or every item when the split is smaller.
        expected_size = min(batch_size, len(r2n2_dataset))
        self.assertEqual(len(object_batch["synset_id"]), expected_size)
        self.assertEqual(len(object_batch["model_id"]), expected_size)
        self.assertEqual(len(object_batch["label"]), expected_size)
        self.assertEqual(object_batch["mesh"].verts_padded().shape[0], expected_size)
        self.assertEqual(object_batch["mesh"].faces_padded().shape[0], expected_size)
|