diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py
new file mode 100644
index 00000000..f9ebf584
--- /dev/null
+++ b/pytorch3d/datasets/shapenet/shapenet_core.py
@@ -0,0 +1,74 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+import os
+import warnings
+from os import path
+
+import torch
+from pytorch3d.io import load_obj
+
+
+class ShapeNetCore(torch.utils.data.Dataset):
+    """
+    This class loads ShapeNetCore v.1 from a given directory into a Dataset object.
+    """
+
+    def __init__(self, data_dir):
+        """
+        Stores each object's synset id and model id from data_dir.
+        Args:
+            data_dir (path): Path to the ShapeNetCore dataset.
+        """
+        self.data_dir = data_dir
+
+        # List of subdirectories of data_dir, each containing one category of models.
+        # The name of each subdirectory is the WordNet synset offset of that category.
+        wnsynset_list = [
+            wnsynset
+            for wnsynset in os.listdir(data_dir)
+            if path.isdir(path.join(data_dir, wnsynset))
+        ]
+
+        # Extract the synset_id and model_id of each object from the directory names.
+        # Each grandchild directory of data_dir contains one object, and the name
+        # of that directory is the object's model_id.
+        self.synset_ids = []
+        self.model_ids = []
+        for synset in wnsynset_list:
+            for model in os.listdir(path.join(data_dir, synset)):
+                if not path.exists(path.join(data_dir, synset, model, "model.obj")):
+                    msg = (
+                        "model.obj not found in the model directory %s "
+                        "under synset directory %s."
+                        % (model, synset)
+                    )
+                    warnings.warn(msg)
+                else:
+                    self.synset_ids.append(synset)
+                    self.model_ids.append(model)
+
+    def __len__(self):
+        """
+        Returns the total number of models in ShapeNetCore.
+        """
+        return len(self.model_ids)
+
+    def __getitem__(self, idx):
+        """
+        Read the model at the given index.
+        Returns:
+            dictionary with the following keys:
+            - verts: FloatTensor of shape (V, 3).
+            - faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
+            - synset_id (str): synset id.
+            - model_id (str): model id.
+        """
+        model = {}
+        model["synset_id"] = self.synset_ids[idx]
+        model["model_id"] = self.model_ids[idx]
+        model_path = path.join(
+            self.data_dir, model["synset_id"], model["model_id"], "model.obj"
+        )
+        model["verts"], faces, _ = load_obj(model_path)
+        model["faces"] = faces.verts_idx
+        return model
diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py
new file mode 100644
index 00000000..55f2b1d4
--- /dev/null
+++ b/tests/test_shapenet_core.py
@@ -0,0 +1,62 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+"""
+Sanity checks for loading ShapeNet Core v1.
+"""
+import os
+import random
+import unittest
+import warnings
+
+import torch
+from common_testing import TestCaseMixin
+from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
+
+
+SHAPENET_PATH = None
+
+
+class TestShapenetCore(TestCaseMixin, unittest.TestCase):
+    def test_load_shapenet_core(self):
+
+        # The ShapeNet dataset is not provided in the repo.
+        # Download it separately and update `SHAPENET_PATH` at the top of this file
+        # with the location of the dataset in order to run this test.
+        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
+            url = "https://www.shapenet.org/"
+            msg = (
+                "ShapeNet data not found; download it from %s, save it at the path %s, "
+                "update SHAPENET_PATH at the top of the file, and rerun."
+                % (url, SHAPENET_PATH)
+            )
+            warnings.warn(msg)
+            return
+
+        # Load ShapeNetCore without specifying any particular categories.
+        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
+
+        # Count the number of grandchild directories (which should be equal to
+        # the total number of objects in the dataset) by walking through the given
+        # directory.
+        wnsynset_list = [
+            wnsynset
+            for wnsynset in os.listdir(SHAPENET_PATH)
+            if os.path.isdir(os.path.join(SHAPENET_PATH, wnsynset))
+        ]
+        model_num_list = [
+            len(next(os.walk(os.path.join(SHAPENET_PATH, wnsynset)))[1])
+            for wnsynset in wnsynset_list
+        ]
+        # Check that the total number of objects in the dataset is correct.
+        self.assertEqual(len(shapenet_dataset), sum(model_num_list))
+
+        # Randomly retrieve an object from the dataset.
+        rand_obj = random.choice(shapenet_dataset)
+        self.assertEqual(len(rand_obj), 4)
+        # Check that the data types and shapes of items returned by __getitem__ are correct.
+        verts, faces = rand_obj["verts"], rand_obj["faces"]
+        self.assertTrue(verts.dtype == torch.float32)
+        self.assertTrue(faces.dtype == torch.int64)
+        self.assertEqual(verts.ndim, 2)
+        self.assertEqual(verts.shape[-1], 3)
+        self.assertEqual(faces.ndim, 2)
+        self.assertEqual(faces.shape[-1], 3)
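
Note for reviewers: a minimal usage sketch of the new dataset class, not part of the diff. The dataset path below is a placeholder, and wrapping the output in a Meshes structure is only one possible way to consume it:

    # Minimal usage sketch, assuming a local copy of ShapeNetCore v1 at a placeholder path.
    from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
    from pytorch3d.structures import Meshes

    shapenet_dataset = ShapeNetCore("/path/to/ShapeNetCore.v1")  # placeholder path
    print(len(shapenet_dataset))  # total number of models found on disk

    model = shapenet_dataset[0]
    print(model["synset_id"], model["model_id"])
    print(model["verts"].shape, model["faces"].shape)  # (V, 3) and (F, 3)

    # The returned tensors can be wrapped in a Meshes structure for downstream ops.
    mesh = Meshes(verts=[model["verts"]], faces=[model["faces"]])

Batching with a torch DataLoader would need a custom collate_fn, since different models have different numbers of vertices and faces; that is left out of this sketch.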