From 2ea6a7d8ad219c7b0a8b557b1ddcc2bfba450875 Mon Sep 17 00:00:00 2001 From: Luya Gao Date: Wed, 17 Jun 2020 20:29:23 -0700 Subject: [PATCH] Adding support for selecting categories and ver2 for ShapeNetCore Summary: Adding support so that users can select which categories they would like to load with wordnet synset offsets or labels or a combination of both. ShapeNetCore now also supports loading v2. Reviewed By: nikhilaravi Differential Revision: D22039207 fbshipit-source-id: 1f0218acb790e5561e2ae373e99cebb9823eea1a --- pytorch3d/datasets/__init__.py | 5 + pytorch3d/datasets/shapenet/__init__.py | 5 + pytorch3d/datasets/shapenet/shapenet_core.py | 96 +++++++++++++++---- .../shapenet/shapenet_synset_dict_v1.json | 59 ++++++++++++ .../shapenet/shapenet_synset_dict_v2.json | 57 +++++++++++ tests/test_shapenet_core.py | 37 ++++++- 6 files changed, 240 insertions(+), 19 deletions(-) create mode 100644 pytorch3d/datasets/__init__.py create mode 100644 pytorch3d/datasets/shapenet/__init__.py create mode 100644 pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json create mode 100644 pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py new file mode 100644 index 00000000..3cf0f3f3 --- /dev/null +++ b/pytorch3d/datasets/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from .shapenet import ShapeNetCore + + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/datasets/shapenet/__init__.py b/pytorch3d/datasets/shapenet/__init__.py new file mode 100644 index 00000000..dd0bc863 --- /dev/null +++ b/pytorch3d/datasets/shapenet/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+from .shapenet_core import ShapeNetCore + + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py index f9ebf584..25b6bca0 100644 --- a/pytorch3d/datasets/shapenet/shapenet_core.py +++ b/pytorch3d/datasets/shapenet/shapenet_core.py @@ -1,43 +1,103 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import json import os import warnings from os import path +from pathlib import Path import torch from pytorch3d.io import load_obj +SYNSET_DICT_DIR = Path(__file__).resolve().parent + + class ShapeNetCore(torch.utils.data.Dataset): """ - This class loads ShapeNet v.1 from a given directory into a Dataset object. + This class loads ShapeNetCore from a given directory into a Dataset object. + ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from + https://www.shapenet.org/. """ - def __init__(self, data_dir): + def __init__(self, data_dir, synsets=None, version: int = 1): """ - Stores each object's synset id and models id from data_dir. + Store each object's synset id and models id from data_dir. Args: - data_dir (path): Path to shapenet data + data_dir: Path to ShapeNetCore data. + synsets: List of synset categories to load from ShapeNetCore in the form of + synset offsets or labels. A combination of both is also accepted. + When no category is specified, all categories in data_dir are loaded. + version: (int) version of ShapeNetCore data in data_dir, 1 or 2. + Default is set to be 1. Version 1 has 57 categories and version 2 has 55 + categories. + Note: version 1 has two categories 02858304(boat) and 02992529(cellphone) + that are hyponyms of categories 04530566(watercraft) and 04401088(telephone) + respectively. You can combine the categories manually if needed. + Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to + version 1. 
+ """ self.data_dir = data_dir + if version not in [1, 2]: + raise ValueError("Version number must be either 1 or 2.") + self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj" - # List of subdirectories of data_dir each containing a category of models. - # The name of each subdirectory is the wordnet synset offset of that category. - wnsynset_list = [ - wnsynset - for wnsynset in os.listdir(data_dir) - if path.isdir(path.join(data_dir, wnsynset)) - ] + # Synset dictionary mapping synset offsets to corresponding labels. + dict_file = "shapenet_synset_dict_v%d.json" % version + with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict: + self.synset_dict = json.load(read_dict) + # Inverse dictionary mapping synset labels to corresponding offsets. + synset_inv = {label: offset for offset, label in self.synset_dict.items()} - # Extract synset_id and model_id of each object from directory names. + # If categories are specified, check if each category is in the form of either + # synset offset or synset label, and if the category exists in the given directory. + if synsets is not None: + # Set of categories to load in the form of synset offsets. + synset_set = set() + for synset in synsets: + if (synset in self.synset_dict.keys()) and ( + path.isdir(path.join(data_dir, synset)) + ): + synset_set.add(synset) + elif (synset in synset_inv.keys()) and ( + (path.isdir(path.join(data_dir, synset_inv[synset]))) + ): + synset_set.add(synset_inv[synset]) + else: + msg = """Synset category %s either not part of ShapeNetCore dataset + or cannot be found in %s.""" % ( + synset, + data_dir, + ) + warnings.warn(msg) + # If no category is given, load every category in the given directory. 
+ else: + synset_set = { + synset + for synset in os.listdir(data_dir) + if path.isdir(path.join(data_dir, synset)) + } + for synset in self.synset_dict: + if synset not in synset_set: + msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s + but not found in %s.""" % ( + synset, + self.synset_dict[synset], + version, + data_dir, + ) + warnings.warn(msg) + + # Extract model_id of each object from directory names. # Each grandchildren directory of data_dir contains an object, and the name # of the directory is the object's model_id. self.synset_ids = [] self.model_ids = [] - for synset in wnsynset_list: + for synset in synset_set: for model in os.listdir(path.join(data_dir, synset)): - if not path.exists(path.join(data_dir, synset, model, "model.obj")): - msg = """ model.obj not found in the model directory %s + if not path.exists(path.join(data_dir, synset, model, self.model_dir)): + msg = """ Object file not found in the model directory %s under synset directory %s.""" % ( model, synset, @@ -49,7 +109,7 @@ class ShapeNetCore(torch.utils.data.Dataset): def __len__(self): """ - Returns # of total models in shapenet core + Return number of total models in shapenet core. """ return len(self.model_ids) @@ -62,13 +122,15 @@ class ShapeNetCore(torch.utils.data.Dataset): - faces: LongTensor of shape (F, 3) which indexes into the verts tensor. - synset_id (str): synset id - model_id (str): model id + - label (str): synset label. 
""" model = {} model["synset_id"] = self.synset_ids[idx] model["model_id"] = self.model_ids[idx] model_path = path.join( - self.data_dir, model["synset_id"], model["model_id"], "model.obj" + self.data_dir, model["synset_id"], model["model_id"], self.model_dir ) model["verts"], faces, _ = load_obj(model_path) model["faces"] = faces.verts_idx + model["label"] = self.synset_dict[model["synset_id"]] return model diff --git a/pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json b/pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json new file mode 100644 index 00000000..b2fc62ae --- /dev/null +++ b/pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json @@ -0,0 +1,59 @@ +{ + "04379243": "table", + "02958343": "car", + "03001627": "chair", + "02691156": "airplane", + "04256520": "sofa", + "04090263": "rifle", + "03636649": "lamp", + "04530566": "watercraft", + "02828884": "bench", + "03691459": "loudspeaker", + "02933112": "cabinet", + "03211117": "display", + "04401088": "telephone", + "02924116": "bus", + "02808440": "bathtub", + "03467517": "guitar", + "03325088": "faucet", + "03046257": "clock", + "03991062": "flowerpot", + "03593526": "jar", + "02876657": "bottle", + "02871439": "bookshelf", + "03642806": "laptop", + "03624134": "knife", + "04468005": "train", + "02747177": "trash bin", + "03790512": "motorbike", + "03948459": "pistol", + "03337140": "file cabinet", + "02818832": "bed", + "03928116": "piano", + "04330267": "stove", + "03797390": "mug", + "02880940": "bowl", + "04554684": "washer", + "04004475": "printer", + "03513137": "helmet", + "03761084": "microwaves", + "04225987": "skateboard", + "04460130": "tower", + "02942699": "camera", + "02801938": "basket", + "02946921": "can", + "03938244": "pillow", + "03710193": "mailbox", + "03207941": "dishwasher", + "04099429": "rocket", + "02773838": "bag", + "02843684": "birdhouse", + "03261776": "earphone", + "03759954": "microphone", + "04074963": "remote", + "03085013": "keyboard", + "02834778": 
"bicycle", + "02954340": "cap", + "02858304": "boat", + "02992529": "mobile phone" +} diff --git a/pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json b/pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json new file mode 100644 index 00000000..f0107c93 --- /dev/null +++ b/pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json @@ -0,0 +1,57 @@ +{ + "02691156": "airplane", + "02747177": "trash bin", + "02773838": "bag", + "02801938": "basket", + "02808440": "bathtub", + "02818832": "bed", + "02828884": "bench", + "02843684": "birdhouse", + "02871439": "bookshelf", + "02876657": "bottle", + "02880940": "bowl", + "02924116": "bus", + "02933112": "cabinet", + "02942699": "camera", + "02946921": "can", + "02954340": "cap", + "02958343": "car", + "02992529": "cellphone", + "03001627": "chair", + "03046257": "clock", + "03085013": "keyboard", + "03207941": "dishwasher", + "03211117": "display", + "03261776": "earphone", + "03325088": "faucet", + "03337140": "file cabinet", + "03467517": "guitar", + "03513137": "helmet", + "03593526": "jar", + "03624134": "knife", + "03636649": "lamp", + "03642806": "laptop", + "03691459": "loudspeaker", + "03710193": "mailbox", + "03759954": "microphone", + "03761084": "microwaves", + "03790512": "motorbike", + "03797390": "mug", + "03928116": "piano", + "03938244": "pillow", + "03948459": "pistol", + "03991062": "flowerpot", + "04004475": "printer", + "04074963": "remote", + "04090263": "rifle", + "04099429": "rocket", + "04225987": "skateboard", + "04256520": "sofa", + "04330267": "stove", + "04379243": "table", + "04401088": "telephone", + "04460130": "tower", + "04468005": "train", + "04530566": "watercraft", + "04554684": "washer" +} diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py index 55f2b1d4..ff623f78 100644 --- a/tests/test_shapenet_core.py +++ b/tests/test_shapenet_core.py @@ -9,7 +9,7 @@ import warnings import torch from common_testing import TestCaseMixin -from 
pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore +from pytorch3d.datasets import ShapeNetCore SHAPENET_PATH = None @@ -31,6 +31,11 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): warnings.warn(msg) return True + # Try load ShapeNetCore with an invalid version number and catch error. + with self.assertRaises(ValueError) as err: + ShapeNetCore(SHAPENET_PATH, version=3) + self.assertTrue("Version number must be either 1 or 2." in str(err.exception)) + # Load ShapeNetCore without specifying any particular categories. shapenet_dataset = ShapeNetCore(SHAPENET_PATH) @@ -51,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): # Randomly retrieve an object from the dataset. rand_obj = random.choice(shapenet_dataset) - self.assertEqual(len(rand_obj), 4) + self.assertEqual(len(rand_obj), 5) # Check that data types and shapes of items returned by __getitem__ are correct. verts, faces = rand_obj["verts"], rand_obj["faces"] self.assertTrue(verts.dtype == torch.float32) @@ -60,3 +65,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): self.assertEqual(verts.shape[-1], 3) self.assertEqual(faces.ndim, 2) self.assertEqual(faces.shape[-1], 3) + + # Load six categories from ShapeNetCore. + # Specify categories in the form of a combination of offsets and labels. + shapenet_subset = ShapeNetCore( + SHAPENET_PATH, + synsets=[ + "04330267", + "guitar", + "02801938", + "birdhouse", + "03991062", + "tower", + ], + version=1, + ) + subset_offsets = [ + "04330267", + "03467517", + "02801938", + "02843684", + "03991062", + "04460130", + ] + subset_model_nums = [ + (len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1])) + for offset in subset_offsets + ] + self.assertEqual(len(shapenet_subset), sum(subset_model_nums))