mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Adding support for selecting categories and ver2 for ShapeNetCore
Summary: Adding support so that users can select which categories they would like to load with wordnet synset offsets or labels or a combination of both. ShapeNetCore now also supports loading v2. Reviewed By: nikhilaravi Differential Revision: D22039207 fbshipit-source-id: 1f0218acb790e5561e2ae373e99cebb9823eea1a
This commit is contained in:
parent
9d279ba543
commit
2ea6a7d8ad
5
pytorch3d/datasets/__init__.py
Normal file
5
pytorch3d/datasets/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from .shapenet import ShapeNetCore
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
5
pytorch3d/datasets/shapenet/__init__.py
Normal file
5
pytorch3d/datasets/shapenet/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from .shapenet_core import ShapeNetCore
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
@ -1,43 +1,103 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from pytorch3d.io import load_obj
|
||||
|
||||
|
||||
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class ShapeNetCore(torch.utils.data.Dataset):
|
||||
"""
|
||||
This class loads ShapeNet v.1 from a given directory into a Dataset object.
|
||||
This class loads ShapeNetCore from a given directory into a Dataset object.
|
||||
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
|
||||
https://www.shapenet.org/.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir):
|
||||
def __init__(self, data_dir, synsets=None, version: int = 1):
|
||||
"""
|
||||
Stores each object's synset id and models id from data_dir.
|
||||
Store each object's synset id and models id from data_dir.
|
||||
Args:
|
||||
data_dir (path): Path to shapenet data
|
||||
data_dir: Path to ShapeNetCore data.
|
||||
synsets: List of synset categories to load from ShapeNetCore in the form of
|
||||
synset offsets or labels. A combination of both is also accepted.
|
||||
When no category is specified, all categories in data_dir are loaded.
|
||||
version: (int) version of ShapeNetCore data in data_dir, 1 or 2.
|
||||
Default is set to be 1. Version 1 has 57 categories and verions 2 has 55
|
||||
categories.
|
||||
Note: version 1 has two categories 02858304(boat) and 02992529(cellphone)
|
||||
that are hyponyms of categories 04530566(watercraft) and 04401088(telephone)
|
||||
respectively. You can combine the categories manually if needed.
|
||||
Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to
|
||||
version 1.
|
||||
|
||||
"""
|
||||
self.data_dir = data_dir
|
||||
if version not in [1, 2]:
|
||||
raise ValueError("Version number must be either 1 or 2.")
|
||||
self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
|
||||
|
||||
# List of subdirectories of data_dir each containing a category of models.
|
||||
# The name of each subdirectory is the wordnet synset offset of that category.
|
||||
wnsynset_list = [
|
||||
wnsynset
|
||||
for wnsynset in os.listdir(data_dir)
|
||||
if path.isdir(path.join(data_dir, wnsynset))
|
||||
]
|
||||
# Synset dictionary mapping synset offsets to corresponding labels.
|
||||
dict_file = "shapenet_synset_dict_v%d.json" % version
|
||||
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
|
||||
self.synset_dict = json.load(read_dict)
|
||||
# Inverse dicitonary mapping synset labels to corresponding offsets.
|
||||
synset_inv = {label: offset for offset, label in self.synset_dict.items()}
|
||||
|
||||
# Extract synset_id and model_id of each object from directory names.
|
||||
# If categories are specified, check if each category is in the form of either
|
||||
# synset offset or synset label, and if the category exists in the given directory.
|
||||
if synsets is not None:
|
||||
# Set of categories to load in the form of synset offsets.
|
||||
synset_set = set()
|
||||
for synset in synsets:
|
||||
if (synset in self.synset_dict.keys()) and (
|
||||
path.isdir(path.join(data_dir, synset))
|
||||
):
|
||||
synset_set.add(synset)
|
||||
elif (synset in synset_inv.keys()) and (
|
||||
(path.isdir(path.join(data_dir, synset_inv[synset])))
|
||||
):
|
||||
synset_set.add(synset_inv[synset])
|
||||
else:
|
||||
msg = """Synset category %s either not part of ShapeNetCore dataset
|
||||
or cannot be found in %s.""" % (
|
||||
synset,
|
||||
data_dir,
|
||||
)
|
||||
warnings.warn(msg)
|
||||
# If no category is given, load every category in the given directory.
|
||||
else:
|
||||
synset_set = {
|
||||
synset
|
||||
for synset in os.listdir(data_dir)
|
||||
if path.isdir(path.join(data_dir, synset))
|
||||
}
|
||||
for synset in synset_set:
|
||||
if synset not in self.synset_dict.keys():
|
||||
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
|
||||
but not found in %s.""" % (
|
||||
synset,
|
||||
self.synset_dict[synset],
|
||||
version,
|
||||
data_dir,
|
||||
)
|
||||
warnings.warn(msg)
|
||||
|
||||
# Extract model_id of each object from directory names.
|
||||
# Each grandchildren directory of data_dir contains an object, and the name
|
||||
# of the directory is the object's model_id.
|
||||
self.synset_ids = []
|
||||
self.model_ids = []
|
||||
for synset in wnsynset_list:
|
||||
for synset in synset_set:
|
||||
for model in os.listdir(path.join(data_dir, synset)):
|
||||
if not path.exists(path.join(data_dir, synset, model, "model.obj")):
|
||||
msg = """ model.obj not found in the model directory %s
|
||||
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
|
||||
msg = """ Object file not found in the model directory %s
|
||||
under synset directory %s.""" % (
|
||||
model,
|
||||
synset,
|
||||
@ -49,7 +109,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns # of total models in shapenet core
|
||||
Return number of total models in shapenet core.
|
||||
"""
|
||||
return len(self.model_ids)
|
||||
|
||||
@ -62,13 +122,15 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
||||
- faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
|
||||
- synset_id (str): synset id
|
||||
- model_id (str): model id
|
||||
- label (str): synset label.
|
||||
"""
|
||||
model = {}
|
||||
model["synset_id"] = self.synset_ids[idx]
|
||||
model["model_id"] = self.model_ids[idx]
|
||||
model_path = path.join(
|
||||
self.data_dir, model["synset_id"], model["model_id"], "model.obj"
|
||||
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
|
||||
)
|
||||
model["verts"], faces, _ = load_obj(model_path)
|
||||
model["faces"] = faces.verts_idx
|
||||
model["label"] = self.synset_dict[model["synset_id"]]
|
||||
return model
|
||||
|
59
pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json
Normal file
59
pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json
Normal file
@ -0,0 +1,59 @@
|
||||
{
|
||||
"04379243": "table",
|
||||
"02958343": "car",
|
||||
"03001627": "chair",
|
||||
"02691156": "airplane",
|
||||
"04256520": "sofa",
|
||||
"04090263": "rifle",
|
||||
"03636649": "lamp",
|
||||
"04530566": "watercraft",
|
||||
"02828884": "bench",
|
||||
"03691459": "loudspeaker",
|
||||
"02933112": "cabinet",
|
||||
"03211117": "display",
|
||||
"04401088": "telephone",
|
||||
"02924116": "bus",
|
||||
"02808440": "bathtub",
|
||||
"03467517": "guitar",
|
||||
"03325088": "faucet",
|
||||
"03046257": "clock",
|
||||
"03991062": "flowerpot",
|
||||
"03593526": "jar",
|
||||
"02876657": "bottle",
|
||||
"02871439": "bookshelf",
|
||||
"03642806": "laptop",
|
||||
"03624134": "knife",
|
||||
"04468005": "train",
|
||||
"02747177": "trash bin",
|
||||
"03790512": "motorbike",
|
||||
"03948459": "pistol",
|
||||
"03337140": "file cabinet",
|
||||
"02818832": "bed",
|
||||
"03928116": "piano",
|
||||
"04330267": "stove",
|
||||
"03797390": "mug",
|
||||
"02880940": "bowl",
|
||||
"04554684": "washer",
|
||||
"04004475": "printer",
|
||||
"03513137": "helmet",
|
||||
"03761084": "microwaves",
|
||||
"04225987": "skateboard",
|
||||
"04460130": "tower",
|
||||
"02942699": "camera",
|
||||
"02801938": "basket",
|
||||
"02946921": "can",
|
||||
"03938244": "pillow",
|
||||
"03710193": "mailbox",
|
||||
"03207941": "dishwasher",
|
||||
"04099429": "rocket",
|
||||
"02773838": "bag",
|
||||
"02843684": "birdhouse",
|
||||
"03261776": "earphone",
|
||||
"03759954": "microphone",
|
||||
"04074963": "remote",
|
||||
"03085013": "keyboard",
|
||||
"02834778": "bicycle",
|
||||
"02954340": "cap",
|
||||
"02858304": "boat",
|
||||
"02992529": "mobile phone"
|
||||
}
|
57
pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json
Normal file
57
pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json
Normal file
@ -0,0 +1,57 @@
|
||||
{
|
||||
"02691156": "airplane",
|
||||
"02747177": "trash bin",
|
||||
"02773838": "bag",
|
||||
"02801938": "basket",
|
||||
"02808440": "bathtub",
|
||||
"02818832": "bed",
|
||||
"02828884": "bench",
|
||||
"02843684": "birdhouse",
|
||||
"02871439": "bookshelf",
|
||||
"02876657": "bottle",
|
||||
"02880940": "bowl",
|
||||
"02924116": "bus",
|
||||
"02933112": "cabinet",
|
||||
"02942699": "camera",
|
||||
"02946921": "can",
|
||||
"02954340": "cap",
|
||||
"02958343": "car",
|
||||
"02992529": "cellphone",
|
||||
"03001627": "chair",
|
||||
"03046257": "clock",
|
||||
"03085013": "keyboard",
|
||||
"03207941": "dishwasher",
|
||||
"03211117": "display",
|
||||
"03261776": "earphone",
|
||||
"03325088": "faucet",
|
||||
"03337140": "file cabinet",
|
||||
"03467517": "guitar",
|
||||
"03513137": "helmet",
|
||||
"03593526": "jar",
|
||||
"03624134": "knife",
|
||||
"03636649": "lamp",
|
||||
"03642806": "laptop",
|
||||
"03691459": "loudspeaker",
|
||||
"03710193": "mailbox",
|
||||
"03759954": "microphone",
|
||||
"03761084": "microwaves",
|
||||
"03790512": "motorbike",
|
||||
"03797390": "mug",
|
||||
"03928116": "piano",
|
||||
"03938244": "pillow",
|
||||
"03948459": "pistol",
|
||||
"03991062": "flowerpot",
|
||||
"04004475": "printer",
|
||||
"04074963": "remote",
|
||||
"04090263": "rifle",
|
||||
"04099429": "rocket",
|
||||
"04225987": "skateboard",
|
||||
"04256520": "sofa",
|
||||
"04330267": "stove",
|
||||
"04379243": "table",
|
||||
"04401088": "telephone",
|
||||
"04460130": "tower",
|
||||
"04468005": "train",
|
||||
"04530566": "watercraft",
|
||||
"04554684": "washer"
|
||||
}
|
@ -9,7 +9,7 @@ import warnings
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
|
||||
from pytorch3d.datasets import ShapeNetCore
|
||||
|
||||
|
||||
SHAPENET_PATH = None
|
||||
@ -31,6 +31,11 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
warnings.warn(msg)
|
||||
return True
|
||||
|
||||
# Try load ShapeNetCore with an invalid version number and catch error.
|
||||
with self.assertRaises(ValueError) as err:
|
||||
ShapeNetCore(SHAPENET_PATH, version=3)
|
||||
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
|
||||
|
||||
# Load ShapeNetCore without specifying any particular categories.
|
||||
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
|
||||
|
||||
@ -51,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# Randomly retrieve an object from the dataset.
|
||||
rand_obj = random.choice(shapenet_dataset)
|
||||
self.assertEqual(len(rand_obj), 4)
|
||||
self.assertEqual(len(rand_obj), 5)
|
||||
# Check that data types and shapes of items returned by __getitem__ are correct.
|
||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||
self.assertTrue(verts.dtype == torch.float32)
|
||||
@ -60,3 +65,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(verts.shape[-1], 3)
|
||||
self.assertEqual(faces.ndim, 2)
|
||||
self.assertEqual(faces.shape[-1], 3)
|
||||
|
||||
# Load six categories from ShapeNetCore.
|
||||
# Specify categories in the form of a combination of offsets and labels.
|
||||
shapenet_subset = ShapeNetCore(
|
||||
SHAPENET_PATH,
|
||||
synsets=[
|
||||
"04330267",
|
||||
"guitar",
|
||||
"02801938",
|
||||
"birdhouse",
|
||||
"03991062",
|
||||
"tower",
|
||||
],
|
||||
version=1,
|
||||
)
|
||||
subset_offsets = [
|
||||
"04330267",
|
||||
"03467517",
|
||||
"02801938",
|
||||
"02843684",
|
||||
"03991062",
|
||||
"04460130",
|
||||
]
|
||||
subset_model_nums = [
|
||||
(len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1]))
|
||||
for offset in subset_offsets
|
||||
]
|
||||
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
|
||||
|
Loading…
x
Reference in New Issue
Block a user