Adding support for selecting categories and ver2 for ShapeNetCore

Summary: Adds support for selecting which categories to load, specified by WordNet synset offsets, synset labels, or a combination of both. ShapeNetCore now also supports loading v2 of the dataset.

Reviewed By: nikhilaravi

Differential Revision: D22039207

fbshipit-source-id: 1f0218acb790e5561e2ae373e99cebb9823eea1a
Luya Gao 2020-06-17 20:29:23 -07:00 committed by Facebook GitHub Bot
parent 9d279ba543
commit 2ea6a7d8ad
6 changed files with 240 additions and 19 deletions
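
For orientation, the snippet below is a minimal usage sketch of the API described in the summary; it is not part of the commit, and the dataset path and the particular categories are illustrative assumptions.

from pytorch3d.datasets import ShapeNetCore

# Hypothetical local path to a ShapeNetCore v2 download; adjust as needed.
SHAPENET_PATH = "/path/to/ShapeNetCore.v2"

# Categories may be given as WordNet synset offsets, labels, or a mix of both.
dataset = ShapeNetCore(
    SHAPENET_PATH,
    synsets=["04379243", "chair", "lamp"],  # table by offset, chair and lamp by label
    version=2,
)

model = dataset[0]
print(model["synset_id"], model["model_id"], model["label"])
print(model["verts"].shape, model["faces"].shape)  # (V, 3) verts, (F, 3) face indices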


@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .shapenet import ShapeNetCore
__all__ = [k for k in globals().keys() if not k.startswith("_")]


@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .shapenet_core import ShapeNetCore
__all__ = [k for k in globals().keys() if not k.startswith("_")]


@@ -1,43 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import json
import os
import warnings
from os import path
from pathlib import Path
import torch
from pytorch3d.io import load_obj
SYNSET_DICT_DIR = Path(__file__).resolve().parent
class ShapeNetCore(torch.utils.data.Dataset):
"""
- This class loads ShapeNet v.1 from a given directory into a Dataset object.
+ This class loads ShapeNetCore from a given directory into a Dataset object.
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
https://www.shapenet.org/.
"""
- def __init__(self, data_dir):
+ def __init__(self, data_dir, synsets=None, version: int = 1):
"""
- Stores each object's synset id and models id from data_dir.
+ Store each object's synset id and model id from data_dir.
Args:
- data_dir (path): Path to shapenet data
+ data_dir: Path to ShapeNetCore data.
synsets: List of synset categories to load from ShapeNetCore in the form of
synset offsets or labels. A combination of both is also accepted.
When no category is specified, all categories in data_dir are loaded.
version: (int) version of ShapeNetCore data in data_dir, 1 or 2.
Default is set to be 1. Version 1 has 57 categories and version 2 has 55
categories.
Note: version 1 has two categories 02858304(boat) and 02992529(cellphone)
that are hyponyms of categories 04530566(watercraft) and 04401088(telephone)
respectively. You can combine the categories manually if needed.
Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to
version 1.
"""
self.data_dir = data_dir
if version not in [1, 2]:
raise ValueError("Version number must be either 1 or 2.")
self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
# List of subdirectories of data_dir each containing a category of models.
# The name of each subdirectory is the wordnet synset offset of that category.
wnsynset_list = [
wnsynset
for wnsynset in os.listdir(data_dir)
if path.isdir(path.join(data_dir, wnsynset))
]
# Synset dictionary mapping synset offsets to corresponding labels.
dict_file = "shapenet_synset_dict_v%d.json" % version
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
self.synset_dict = json.load(read_dict)
# Inverse dictionary mapping synset labels to corresponding offsets.
synset_inv = {label: offset for offset, label in self.synset_dict.items()}
# Extract synset_id and model_id of each object from directory names.
# If categories are specified, check if each category is in the form of either
# synset offset or synset label, and if the category exists in the given directory.
if synsets is not None:
# Set of categories to load in the form of synset offsets.
synset_set = set()
for synset in synsets:
if (synset in self.synset_dict.keys()) and (
path.isdir(path.join(data_dir, synset))
):
synset_set.add(synset)
elif (synset in synset_inv.keys()) and (
(path.isdir(path.join(data_dir, synset_inv[synset])))
):
synset_set.add(synset_inv[synset])
else:
msg = """Synset category %s either not part of ShapeNetCore dataset
or cannot be found in %s.""" % (
synset,
data_dir,
)
warnings.warn(msg)
# If no category is given, load every category in the given directory.
else:
synset_set = {
synset
for synset in os.listdir(data_dir)
if path.isdir(path.join(data_dir, synset))
}
# Warn about categories in the official mapping that are not present in data_dir.
for synset in self.synset_dict.keys():
if synset not in synset_set:
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
but not found in %s.""" % (
synset,
self.synset_dict[synset],
version,
data_dir,
)
warnings.warn(msg)
# Extract model_id of each object from directory names.
# Each grandchild directory of data_dir contains an object, and the name
# of the directory is the object's model_id.
self.synset_ids = []
self.model_ids = []
- for synset in wnsynset_list:
+ for synset in synset_set:
for model in os.listdir(path.join(data_dir, synset)):
- if not path.exists(path.join(data_dir, synset, model, "model.obj")):
- msg = """ model.obj not found in the model directory %s
+ if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
+ msg = """ Object file not found in the model directory %s
under synset directory %s.""" % (
model,
synset,
@@ -49,7 +109,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
def __len__(self):
"""
- Returns # of total models in shapenet core
+ Return the total number of models in ShapeNetCore.
"""
return len(self.model_ids)
@@ -62,13 +122,15 @@ class ShapeNetCore(torch.utils.data.Dataset):
- faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
- synset_id (str): synset id
- model_id (str): model id
- label (str): synset label.
"""
model = {}
model["synset_id"] = self.synset_ids[idx]
model["model_id"] = self.model_ids[idx]
model_path = path.join(
- self.data_dir, model["synset_id"], model["model_id"], "model.obj"
+ self.data_dir, model["synset_id"], model["model_id"], self.model_dir
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx
model["label"] = self.synset_dict[model["synset_id"]]
return model
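
ShapeNetCore is a standard torch Dataset, but each model has a different number of vertices and faces, so the default DataLoader collate cannot stack a batch. The sketch below (not part of this commit; the path and category are assumptions) shows one simple workaround, a collate_fn that keeps each field as a list.

from torch.utils.data import DataLoader

from pytorch3d.datasets import ShapeNetCore

SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # hypothetical local copy


def collate_as_lists(batch):
    # Keep every field as a Python list; verts and faces differ in size per model,
    # so they cannot be stacked into a single tensor.
    return {key: [sample[key] for sample in batch] for key in batch[0]}


chairs = ShapeNetCore(SHAPENET_PATH, synsets=["chair"], version=1)
loader = DataLoader(chairs, batch_size=4, collate_fn=collate_as_lists)
for batch in loader:
    # batch["verts"] is a list of (V_i, 3) float tensors, one per model.
    print(batch["label"], [v.shape for v in batch["verts"]])
    break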


@@ -0,0 +1,59 @@
{
"04379243": "table",
"02958343": "car",
"03001627": "chair",
"02691156": "airplane",
"04256520": "sofa",
"04090263": "rifle",
"03636649": "lamp",
"04530566": "watercraft",
"02828884": "bench",
"03691459": "loudspeaker",
"02933112": "cabinet",
"03211117": "display",
"04401088": "telephone",
"02924116": "bus",
"02808440": "bathtub",
"03467517": "guitar",
"03325088": "faucet",
"03046257": "clock",
"03991062": "flowerpot",
"03593526": "jar",
"02876657": "bottle",
"02871439": "bookshelf",
"03642806": "laptop",
"03624134": "knife",
"04468005": "train",
"02747177": "trash bin",
"03790512": "motorbike",
"03948459": "pistol",
"03337140": "file cabinet",
"02818832": "bed",
"03928116": "piano",
"04330267": "stove",
"03797390": "mug",
"02880940": "bowl",
"04554684": "washer",
"04004475": "printer",
"03513137": "helmet",
"03761084": "microwaves",
"04225987": "skateboard",
"04460130": "tower",
"02942699": "camera",
"02801938": "basket",
"02946921": "can",
"03938244": "pillow",
"03710193": "mailbox",
"03207941": "dishwasher",
"04099429": "rocket",
"02773838": "bag",
"02843684": "birdhouse",
"03261776": "earphone",
"03759954": "microphone",
"04074963": "remote",
"03085013": "keyboard",
"02834778": "bicycle",
"02954340": "cap",
"02858304": "boat",
"02992529": "mobile phone"
}


@@ -0,0 +1,57 @@
{
"02691156": "airplane",
"02747177": "trash bin",
"02773838": "bag",
"02801938": "basket",
"02808440": "bathtub",
"02818832": "bed",
"02828884": "bench",
"02843684": "birdhouse",
"02871439": "bookshelf",
"02876657": "bottle",
"02880940": "bowl",
"02924116": "bus",
"02933112": "cabinet",
"02942699": "camera",
"02946921": "can",
"02954340": "cap",
"02958343": "car",
"02992529": "cellphone",
"03001627": "chair",
"03046257": "clock",
"03085013": "keyboard",
"03207941": "dishwasher",
"03211117": "display",
"03261776": "earphone",
"03325088": "faucet",
"03337140": "file cabinet",
"03467517": "guitar",
"03513137": "helmet",
"03593526": "jar",
"03624134": "knife",
"03636649": "lamp",
"03642806": "laptop",
"03691459": "loudspeaker",
"03710193": "mailbox",
"03759954": "microphone",
"03761084": "microwaves",
"03790512": "motorbike",
"03797390": "mug",
"03928116": "piano",
"03938244": "pillow",
"03948459": "pistol",
"03991062": "flowerpot",
"04004475": "printer",
"04074963": "remote",
"04090263": "rifle",
"04099429": "rocket",
"04225987": "skateboard",
"04256520": "sofa",
"04330267": "stove",
"04379243": "table",
"04401088": "telephone",
"04460130": "tower",
"04468005": "train",
"04530566": "watercraft",
"04554684": "washer"
}
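
The two JSON files above are the offset-to-label dictionaries that ShapeNetCore.__init__ loads. Below is a small sketch of how they let categories be specified by either offset or label, and of the v1/v2 difference noted in the docstring; the file locations are assumed to be pytorch3d/datasets/shapenet/.

import json
from pathlib import Path

SYNSET_DICT_DIR = Path("pytorch3d/datasets/shapenet")  # assumed location

with open(SYNSET_DICT_DIR / "shapenet_synset_dict_v1.json") as f:
    v1 = json.load(f)  # offset -> label, e.g. "03001627" -> "chair"
with open(SYNSET_DICT_DIR / "shapenet_synset_dict_v2.json") as f:
    v2 = json.load(f)

# Inverting the mapping is what allows categories to be passed by label.
synset_inv = {label: offset for offset, label in v2.items()}
print(synset_inv["chair"])  # 03001627

# As the docstring notes, v2 drops bicycle and boat relative to v1.
print(len(v1), len(v2))           # 57 55
print(sorted(set(v1) - set(v2)))  # ['02834778', '02858304'] (bicycle, boat)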


@@ -9,7 +9,7 @@ import warnings
import torch
from common_testing import TestCaseMixin
- from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
+ from pytorch3d.datasets import ShapeNetCore
SHAPENET_PATH = None
@@ -31,6 +31,11 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
warnings.warn(msg)
return True
# Try to load ShapeNetCore with an invalid version number and catch the error.
with self.assertRaises(ValueError) as err:
ShapeNetCore(SHAPENET_PATH, version=3)
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
# Load ShapeNetCore without specifying any particular categories.
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
@@ -51,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
# Randomly retrieve an object from the dataset.
rand_obj = random.choice(shapenet_dataset)
- self.assertEqual(len(rand_obj), 4)
+ self.assertEqual(len(rand_obj), 5)
# Check that data types and shapes of items returned by __getitem__ are correct.
verts, faces = rand_obj["verts"], rand_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
@@ -60,3 +65,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
self.assertEqual(verts.shape[-1], 3)
self.assertEqual(faces.ndim, 2)
self.assertEqual(faces.shape[-1], 3)
# Load six categories from ShapeNetCore.
# Specify categories in the form of a combination of offsets and labels.
shapenet_subset = ShapeNetCore(
SHAPENET_PATH,
synsets=[
"04330267",
"guitar",
"02801938",
"birdhouse",
"03991062",
"tower",
],
version=1,
)
subset_offsets = [
"04330267",
"03467517",
"02801938",
"02843684",
"03991062",
"04460130",
]
subset_model_nums = [
(len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1]))
for offset in subset_offsets
]
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
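
Beyond the assertions in the test above, here is a brief sketch of the two failure modes the new constructor handles: an unsupported version raises a ValueError, while an unrecognized category only triggers a warning and is skipped. The path and the made-up category name are assumptions for illustration.

import warnings

from pytorch3d.datasets import ShapeNetCore

SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # hypothetical local copy

# An unsupported version number raises immediately.
try:
    ShapeNetCore(SHAPENET_PATH, version=3)
except ValueError as err:
    print(err)  # Version number must be either 1 or 2.

# A category that is neither a known offset nor a known label is skipped with a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ShapeNetCore(SHAPENET_PATH, synsets=["chair", "spaceship"], version=1)
print([str(w.message) for w in caught])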