Adding support for selecting categories and ver2 for ShapeNetCore

Summary: Adding support so that users can select which categories they would like to load with wordnet synset offsets or labels or a combination of both. ShapeNetCore now also supports loading v2.

Reviewed By: nikhilaravi

Differential Revision: D22039207

fbshipit-source-id: 1f0218acb790e5561e2ae373e99cebb9823eea1a
This commit is contained in:
Luya Gao 2020-06-17 20:29:23 -07:00 committed by Facebook GitHub Bot
parent 9d279ba543
commit 2ea6a7d8ad
6 changed files with 240 additions and 19 deletions

View File

@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .shapenet import ShapeNetCore
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .shapenet_core import ShapeNetCore
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -1,43 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import json
import os import os
import warnings import warnings
from os import path from os import path
from pathlib import Path
import torch import torch
from pytorch3d.io import load_obj from pytorch3d.io import load_obj
SYNSET_DICT_DIR = Path(__file__).resolve().parent
class ShapeNetCore(torch.utils.data.Dataset): class ShapeNetCore(torch.utils.data.Dataset):
""" """
This class loads ShapeNet v.1 from a given directory into a Dataset object. This class loads ShapeNetCore from a given directory into a Dataset object.
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
https://www.shapenet.org/.
""" """
def __init__(self, data_dir): def __init__(self, data_dir, synsets=None, version: int = 1):
""" """
Stores each object's synset id and models id from data_dir. Store each object's synset id and models id from data_dir.
Args: Args:
data_dir (path): Path to shapenet data data_dir: Path to ShapeNetCore data.
synsets: List of synset categories to load from ShapeNetCore in the form of
synset offsets or labels. A combination of both is also accepted.
When no category is specified, all categories in data_dir are loaded.
version: (int) version of ShapeNetCore data in data_dir, 1 or 2.
Default is set to be 1. Version 1 has 57 categories and version 2 has 55
categories.
Note: version 1 has two categories 02858304(boat) and 02992529(cellphone)
that are hyponyms of categories 04530566(watercraft) and 04401088(telephone)
respectively. You can combine the categories manually if needed.
Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to
version 1.
""" """
self.data_dir = data_dir self.data_dir = data_dir
if version not in [1, 2]:
raise ValueError("Version number must be either 1 or 2.")
self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
# List of subdirectories of data_dir each containing a category of models. # Synset dictionary mapping synset offsets to corresponding labels.
# The name of each subdirectory is the wordnet synset offset of that category. dict_file = "shapenet_synset_dict_v%d.json" % version
wnsynset_list = [ with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
wnsynset self.synset_dict = json.load(read_dict)
for wnsynset in os.listdir(data_dir) # Inverse dictionary mapping synset labels to corresponding offsets.
if path.isdir(path.join(data_dir, wnsynset)) synset_inv = {label: offset for offset, label in self.synset_dict.items()}
]
# Extract synset_id and model_id of each object from directory names. # If categories are specified, check if each category is in the form of either
# synset offset or synset label, and if the category exists in the given directory.
if synsets is not None:
# Set of categories to load in the form of synset offsets.
synset_set = set()
for synset in synsets:
if (synset in self.synset_dict.keys()) and (
path.isdir(path.join(data_dir, synset))
):
synset_set.add(synset)
elif (synset in synset_inv.keys()) and (
(path.isdir(path.join(data_dir, synset_inv[synset])))
):
synset_set.add(synset_inv[synset])
else:
msg = """Synset category %s either not part of ShapeNetCore dataset
or cannot be found in %s.""" % (
synset,
data_dir,
)
warnings.warn(msg)
# If no category is given, load every category in the given directory.
else:
synset_set = {
synset
for synset in os.listdir(data_dir)
if path.isdir(path.join(data_dir, synset))
}
for synset in synset_set:
if synset not in self.synset_dict.keys():
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
but not found in %s.""" % (
synset,
self.synset_dict[synset],
version,
data_dir,
)
warnings.warn(msg)
# Extract model_id of each object from directory names.
# Each grandchildren directory of data_dir contains an object, and the name # Each grandchildren directory of data_dir contains an object, and the name
# of the directory is the object's model_id. # of the directory is the object's model_id.
self.synset_ids = [] self.synset_ids = []
self.model_ids = [] self.model_ids = []
for synset in wnsynset_list: for synset in synset_set:
for model in os.listdir(path.join(data_dir, synset)): for model in os.listdir(path.join(data_dir, synset)):
if not path.exists(path.join(data_dir, synset, model, "model.obj")): if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
msg = """ model.obj not found in the model directory %s msg = """ Object file not found in the model directory %s
under synset directory %s.""" % ( under synset directory %s.""" % (
model, model,
synset, synset,
@ -49,7 +109,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
""" """
Returns # of total models in shapenet core Return number of total models in shapenet core.
""" """
return len(self.model_ids) return len(self.model_ids)
@ -62,13 +122,15 @@ class ShapeNetCore(torch.utils.data.Dataset):
- faces: LongTensor of shape (F, 3) which indexes into the verts tensor. - faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
- synset_id (str): synset id - synset_id (str): synset id
- model_id (str): model id - model_id (str): model id
- label (str): synset label.
""" """
model = {} model = {}
model["synset_id"] = self.synset_ids[idx] model["synset_id"] = self.synset_ids[idx]
model["model_id"] = self.model_ids[idx] model["model_id"] = self.model_ids[idx]
model_path = path.join( model_path = path.join(
self.data_dir, model["synset_id"], model["model_id"], "model.obj" self.data_dir, model["synset_id"], model["model_id"], self.model_dir
) )
model["verts"], faces, _ = load_obj(model_path) model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx model["faces"] = faces.verts_idx
model["label"] = self.synset_dict[model["synset_id"]]
return model return model

View File

@ -0,0 +1,59 @@
{
"04379243": "table",
"02958343": "car",
"03001627": "chair",
"02691156": "airplane",
"04256520": "sofa",
"04090263": "rifle",
"03636649": "lamp",
"04530566": "watercraft",
"02828884": "bench",
"03691459": "loudspeaker",
"02933112": "cabinet",
"03211117": "display",
"04401088": "telephone",
"02924116": "bus",
"02808440": "bathtub",
"03467517": "guitar",
"03325088": "faucet",
"03046257": "clock",
"03991062": "flowerpot",
"03593526": "jar",
"02876657": "bottle",
"02871439": "bookshelf",
"03642806": "laptop",
"03624134": "knife",
"04468005": "train",
"02747177": "trash bin",
"03790512": "motorbike",
"03948459": "pistol",
"03337140": "file cabinet",
"02818832": "bed",
"03928116": "piano",
"04330267": "stove",
"03797390": "mug",
"02880940": "bowl",
"04554684": "washer",
"04004475": "printer",
"03513137": "helmet",
"03761084": "microwaves",
"04225987": "skateboard",
"04460130": "tower",
"02942699": "camera",
"02801938": "basket",
"02946921": "can",
"03938244": "pillow",
"03710193": "mailbox",
"03207941": "dishwasher",
"04099429": "rocket",
"02773838": "bag",
"02843684": "birdhouse",
"03261776": "earphone",
"03759954": "microphone",
"04074963": "remote",
"03085013": "keyboard",
"02834778": "bicycle",
"02954340": "cap",
"02858304": "boat",
"02992529": "mobile phone"
}

View File

@ -0,0 +1,57 @@
{
"02691156": "airplane",
"02747177": "trash bin",
"02773838": "bag",
"02801938": "basket",
"02808440": "bathtub",
"02818832": "bed",
"02828884": "bench",
"02843684": "birdhouse",
"02871439": "bookshelf",
"02876657": "bottle",
"02880940": "bowl",
"02924116": "bus",
"02933112": "cabinet",
"02942699": "camera",
"02946921": "can",
"02954340": "cap",
"02958343": "car",
"02992529": "cellphone",
"03001627": "chair",
"03046257": "clock",
"03085013": "keyboard",
"03207941": "dishwasher",
"03211117": "display",
"03261776": "earphone",
"03325088": "faucet",
"03337140": "file cabinet",
"03467517": "guitar",
"03513137": "helmet",
"03593526": "jar",
"03624134": "knife",
"03636649": "lamp",
"03642806": "laptop",
"03691459": "loudspeaker",
"03710193": "mailbox",
"03759954": "microphone",
"03761084": "microwaves",
"03790512": "motorbike",
"03797390": "mug",
"03928116": "piano",
"03938244": "pillow",
"03948459": "pistol",
"03991062": "flowerpot",
"04004475": "printer",
"04074963": "remote",
"04090263": "rifle",
"04099429": "rocket",
"04225987": "skateboard",
"04256520": "sofa",
"04330267": "stove",
"04379243": "table",
"04401088": "telephone",
"04460130": "tower",
"04468005": "train",
"04530566": "watercraft",
"04554684": "washer"
}

View File

@ -9,7 +9,7 @@ import warnings
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore from pytorch3d.datasets import ShapeNetCore
SHAPENET_PATH = None SHAPENET_PATH = None
@ -31,6 +31,11 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
warnings.warn(msg) warnings.warn(msg)
return True return True
# Try loading ShapeNetCore with an invalid version number and catch the error.
with self.assertRaises(ValueError) as err:
ShapeNetCore(SHAPENET_PATH, version=3)
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
# Load ShapeNetCore without specifying any particular categories. # Load ShapeNetCore without specifying any particular categories.
shapenet_dataset = ShapeNetCore(SHAPENET_PATH) shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
@ -51,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
# Randomly retrieve an object from the dataset. # Randomly retrieve an object from the dataset.
rand_obj = random.choice(shapenet_dataset) rand_obj = random.choice(shapenet_dataset)
self.assertEqual(len(rand_obj), 4) self.assertEqual(len(rand_obj), 5)
# Check that data types and shapes of items returned by __getitem__ are correct. # Check that data types and shapes of items returned by __getitem__ are correct.
verts, faces = rand_obj["verts"], rand_obj["faces"] verts, faces = rand_obj["verts"], rand_obj["faces"]
self.assertTrue(verts.dtype == torch.float32) self.assertTrue(verts.dtype == torch.float32)
@ -60,3 +65,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
self.assertEqual(verts.shape[-1], 3) self.assertEqual(verts.shape[-1], 3)
self.assertEqual(faces.ndim, 2) self.assertEqual(faces.ndim, 2)
self.assertEqual(faces.shape[-1], 3) self.assertEqual(faces.shape[-1], 3)
# Load six categories from ShapeNetCore.
# Specify categories in the form of a combination of offsets and labels.
shapenet_subset = ShapeNetCore(
SHAPENET_PATH,
synsets=[
"04330267",
"guitar",
"02801938",
"birdhouse",
"03991062",
"tower",
],
version=1,
)
subset_offsets = [
"04330267",
"03467517",
"02801938",
"02843684",
"03991062",
"04460130",
]
subset_model_nums = [
(len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1]))
for offset in subset_offsets
]
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))