pytorch3d/tests/test_shapenet_core.py
Luya Gao 9d279ba543 Skeleton of ShapeNetCore class
Summary: Skeleton of the ShapeNetCore class, which loads ShapeNet v1 from a given directory into a Dataset object. It overrides `__init__`, `__len__`, and `__getitem__` from torch.utils.data.Dataset. Currently `__getitem__` returns verts, faces and id_str, where id_str is the concatenation of synset_id and obj_id. Support for loading ShapeNet v2, retrieving textures, and returning WordNet synsets (not just ids) is planned for subsequent diffs.

Reviewed By: nikhilaravi

Differential Revision: D21986222

fbshipit-source-id: c2c515303f1898b6c495b52cb53c74d691585326
2020-06-17 20:31:01 -07:00

63 lines
2.3 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Sanity checks for loading ShapeNet Core v1.
"""
import os
import random
import unittest
import warnings
import torch
from common_testing import TestCaseMixin
from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
# Location of the ShapeNetCore v1 dataset (not shipped with the repo).
# Read from the SHAPENET_PATH environment variable so the test can be enabled
# without editing this file; stays None when the variable is unset, in which
# case the test below does not exercise the dataset.
SHAPENET_PATH = os.environ.get("SHAPENET_PATH")
class TestShapenetCore(TestCaseMixin, unittest.TestCase):
    """Sanity checks for the ShapeNetCore dataset loader."""

    def test_load_shapenet_core(self):
        """
        Load ShapeNetCore from SHAPENET_PATH and check it against the disk.

        Verifies that the dataset length equals the number of model
        directories found two levels below SHAPENET_PATH. It then checks that
        a randomly chosen item exposes 2-D "verts" (float32, last dim 3) and
        2-D "faces" (int64, last dim 3).

        The dataset is not provided in the repo. The test is skipped when
        SHAPENET_PATH is None or does not exist on disk.
        """
        # The ShapeNet dataset is not provided in the repo.
        # Download it separately and set `SHAPENET_PATH` to the location of
        # the dataset in order to run this test.
        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
            url = "https://www.shapenet.org/"
            msg = (
                "ShapeNet data not found, download from %s, update "
                "SHAPENET_PATH at the top of the file, and rerun."
            ) % url
            # Report the test as skipped rather than silently passing;
            # returning a non-None value from a test method is also
            # deprecated since Python 3.11.
            self.skipTest(msg)

        # Load ShapeNetCore without specifying any particular categories.
        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)

        # Count the grandchildren directories of SHAPENET_PATH: each top-level
        # directory groups models by synset, and each of its subdirectories is
        # one model. The total should equal the number of objects in the
        # dataset.
        wnsynset_list = [
            wnsynset
            for wnsynset in os.listdir(SHAPENET_PATH)
            if os.path.isdir(os.path.join(SHAPENET_PATH, wnsynset))
        ]
        model_num_list = [
            # os.walk's first yield is (dirpath, dirnames, filenames) for the
            # top directory itself; [1] is its immediate subdirectories.
            len(next(os.walk(os.path.join(SHAPENET_PATH, wnsynset)))[1])
            for wnsynset in wnsynset_list
        ]
        # Check total number of objects in the dataset is correct.
        self.assertEqual(len(shapenet_dataset), sum(model_num_list))
        # Fail with a clear message instead of an IndexError from
        # random.choice when the directory holds no models.
        self.assertGreater(
            len(shapenet_dataset), 0, "ShapeNetCore dataset is empty."
        )

        # Randomly retrieve an object from the dataset.
        rand_obj = random.choice(shapenet_dataset)
        # NOTE(review): assumes __getitem__ returns exactly four entries
        # (verts, faces, plus identifier fields); confirm against
        # ShapeNetCore.__getitem__.
        self.assertEqual(len(rand_obj), 4)

        # Check that data types and shapes of items returned by __getitem__
        # are correct.
        verts, faces = rand_obj["verts"], rand_obj["faces"]
        self.assertEqual(verts.dtype, torch.float32)
        self.assertEqual(faces.dtype, torch.int64)
        self.assertEqual(verts.ndim, 2)
        self.assertEqual(verts.shape[-1], 3)
        self.assertEqual(faces.ndim, 2)
        self.assertEqual(faces.shape[-1], 3)