mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Adding support for selecting categories and ver2 for ShapeNetCore
Summary: Adding support so that users can select which categories they would like to load with wordnet synset offsets or labels or a combination of both. ShapeNetCore now also supports loading v2. Reviewed By: nikhilaravi Differential Revision: D22039207 fbshipit-source-id: 1f0218acb790e5561e2ae373e99cebb9823eea1a
This commit is contained in:
parent
9d279ba543
commit
2ea6a7d8ad
5
pytorch3d/datasets/__init__.py
Normal file
5
pytorch3d/datasets/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
from .shapenet import ShapeNetCore
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
5
pytorch3d/datasets/shapenet/__init__.py
Normal file
5
pytorch3d/datasets/shapenet/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
from .shapenet_core import ShapeNetCore
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
@ -1,43 +1,103 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from os import path
|
from os import path
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.io import load_obj
|
from pytorch3d.io import load_obj
|
||||||
|
|
||||||
|
|
||||||
|
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
class ShapeNetCore(torch.utils.data.Dataset):
|
class ShapeNetCore(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
This class loads ShapeNet v.1 from a given directory into a Dataset object.
|
This class loads ShapeNetCore from a given directory into a Dataset object.
|
||||||
|
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
|
||||||
|
https://www.shapenet.org/.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_dir):
|
def __init__(self, data_dir, synsets=None, version: int = 1):
|
||||||
"""
|
"""
|
||||||
Stores each object's synset id and models id from data_dir.
|
Store each object's synset id and models id from data_dir.
|
||||||
Args:
|
Args:
|
||||||
data_dir (path): Path to shapenet data
|
data_dir: Path to ShapeNetCore data.
|
||||||
|
synsets: List of synset categories to load from ShapeNetCore in the form of
|
||||||
|
synset offsets or labels. A combination of both is also accepted.
|
||||||
|
When no category is specified, all categories in data_dir are loaded.
|
||||||
|
version: (int) version of ShapeNetCore data in data_dir, 1 or 2.
|
||||||
|
Default is set to be 1. Version 1 has 57 categories and verions 2 has 55
|
||||||
|
categories.
|
||||||
|
Note: version 1 has two categories 02858304(boat) and 02992529(cellphone)
|
||||||
|
that are hyponyms of categories 04530566(watercraft) and 04401088(telephone)
|
||||||
|
respectively. You can combine the categories manually if needed.
|
||||||
|
Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to
|
||||||
|
version 1.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
|
if version not in [1, 2]:
|
||||||
|
raise ValueError("Version number must be either 1 or 2.")
|
||||||
|
self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
|
||||||
|
|
||||||
# List of subdirectories of data_dir each containing a category of models.
|
# Synset dictionary mapping synset offsets to corresponding labels.
|
||||||
# The name of each subdirectory is the wordnet synset offset of that category.
|
dict_file = "shapenet_synset_dict_v%d.json" % version
|
||||||
wnsynset_list = [
|
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
|
||||||
wnsynset
|
self.synset_dict = json.load(read_dict)
|
||||||
for wnsynset in os.listdir(data_dir)
|
# Inverse dicitonary mapping synset labels to corresponding offsets.
|
||||||
if path.isdir(path.join(data_dir, wnsynset))
|
synset_inv = {label: offset for offset, label in self.synset_dict.items()}
|
||||||
]
|
|
||||||
|
|
||||||
# Extract synset_id and model_id of each object from directory names.
|
# If categories are specified, check if each category is in the form of either
|
||||||
|
# synset offset or synset label, and if the category exists in the given directory.
|
||||||
|
if synsets is not None:
|
||||||
|
# Set of categories to load in the form of synset offsets.
|
||||||
|
synset_set = set()
|
||||||
|
for synset in synsets:
|
||||||
|
if (synset in self.synset_dict.keys()) and (
|
||||||
|
path.isdir(path.join(data_dir, synset))
|
||||||
|
):
|
||||||
|
synset_set.add(synset)
|
||||||
|
elif (synset in synset_inv.keys()) and (
|
||||||
|
(path.isdir(path.join(data_dir, synset_inv[synset])))
|
||||||
|
):
|
||||||
|
synset_set.add(synset_inv[synset])
|
||||||
|
else:
|
||||||
|
msg = """Synset category %s either not part of ShapeNetCore dataset
|
||||||
|
or cannot be found in %s.""" % (
|
||||||
|
synset,
|
||||||
|
data_dir,
|
||||||
|
)
|
||||||
|
warnings.warn(msg)
|
||||||
|
# If no category is given, load every category in the given directory.
|
||||||
|
else:
|
||||||
|
synset_set = {
|
||||||
|
synset
|
||||||
|
for synset in os.listdir(data_dir)
|
||||||
|
if path.isdir(path.join(data_dir, synset))
|
||||||
|
}
|
||||||
|
for synset in synset_set:
|
||||||
|
if synset not in self.synset_dict.keys():
|
||||||
|
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
|
||||||
|
but not found in %s.""" % (
|
||||||
|
synset,
|
||||||
|
self.synset_dict[synset],
|
||||||
|
version,
|
||||||
|
data_dir,
|
||||||
|
)
|
||||||
|
warnings.warn(msg)
|
||||||
|
|
||||||
|
# Extract model_id of each object from directory names.
|
||||||
# Each grandchildren directory of data_dir contains an object, and the name
|
# Each grandchildren directory of data_dir contains an object, and the name
|
||||||
# of the directory is the object's model_id.
|
# of the directory is the object's model_id.
|
||||||
self.synset_ids = []
|
self.synset_ids = []
|
||||||
self.model_ids = []
|
self.model_ids = []
|
||||||
for synset in wnsynset_list:
|
for synset in synset_set:
|
||||||
for model in os.listdir(path.join(data_dir, synset)):
|
for model in os.listdir(path.join(data_dir, synset)):
|
||||||
if not path.exists(path.join(data_dir, synset, model, "model.obj")):
|
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
|
||||||
msg = """ model.obj not found in the model directory %s
|
msg = """ Object file not found in the model directory %s
|
||||||
under synset directory %s.""" % (
|
under synset directory %s.""" % (
|
||||||
model,
|
model,
|
||||||
synset,
|
synset,
|
||||||
@ -49,7 +109,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
Returns # of total models in shapenet core
|
Return number of total models in shapenet core.
|
||||||
"""
|
"""
|
||||||
return len(self.model_ids)
|
return len(self.model_ids)
|
||||||
|
|
||||||
@ -62,13 +122,15 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
|||||||
- faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
|
- faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
|
||||||
- synset_id (str): synset id
|
- synset_id (str): synset id
|
||||||
- model_id (str): model id
|
- model_id (str): model id
|
||||||
|
- label (str): synset label.
|
||||||
"""
|
"""
|
||||||
model = {}
|
model = {}
|
||||||
model["synset_id"] = self.synset_ids[idx]
|
model["synset_id"] = self.synset_ids[idx]
|
||||||
model["model_id"] = self.model_ids[idx]
|
model["model_id"] = self.model_ids[idx]
|
||||||
model_path = path.join(
|
model_path = path.join(
|
||||||
self.data_dir, model["synset_id"], model["model_id"], "model.obj"
|
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
|
||||||
)
|
)
|
||||||
model["verts"], faces, _ = load_obj(model_path)
|
model["verts"], faces, _ = load_obj(model_path)
|
||||||
model["faces"] = faces.verts_idx
|
model["faces"] = faces.verts_idx
|
||||||
|
model["label"] = self.synset_dict[model["synset_id"]]
|
||||||
return model
|
return model
|
||||||
|
59
pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json
Normal file
59
pytorch3d/datasets/shapenet/shapenet_synset_dict_v1.json
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
{
|
||||||
|
"04379243": "table",
|
||||||
|
"02958343": "car",
|
||||||
|
"03001627": "chair",
|
||||||
|
"02691156": "airplane",
|
||||||
|
"04256520": "sofa",
|
||||||
|
"04090263": "rifle",
|
||||||
|
"03636649": "lamp",
|
||||||
|
"04530566": "watercraft",
|
||||||
|
"02828884": "bench",
|
||||||
|
"03691459": "loudspeaker",
|
||||||
|
"02933112": "cabinet",
|
||||||
|
"03211117": "display",
|
||||||
|
"04401088": "telephone",
|
||||||
|
"02924116": "bus",
|
||||||
|
"02808440": "bathtub",
|
||||||
|
"03467517": "guitar",
|
||||||
|
"03325088": "faucet",
|
||||||
|
"03046257": "clock",
|
||||||
|
"03991062": "flowerpot",
|
||||||
|
"03593526": "jar",
|
||||||
|
"02876657": "bottle",
|
||||||
|
"02871439": "bookshelf",
|
||||||
|
"03642806": "laptop",
|
||||||
|
"03624134": "knife",
|
||||||
|
"04468005": "train",
|
||||||
|
"02747177": "trash bin",
|
||||||
|
"03790512": "motorbike",
|
||||||
|
"03948459": "pistol",
|
||||||
|
"03337140": "file cabinet",
|
||||||
|
"02818832": "bed",
|
||||||
|
"03928116": "piano",
|
||||||
|
"04330267": "stove",
|
||||||
|
"03797390": "mug",
|
||||||
|
"02880940": "bowl",
|
||||||
|
"04554684": "washer",
|
||||||
|
"04004475": "printer",
|
||||||
|
"03513137": "helmet",
|
||||||
|
"03761084": "microwaves",
|
||||||
|
"04225987": "skateboard",
|
||||||
|
"04460130": "tower",
|
||||||
|
"02942699": "camera",
|
||||||
|
"02801938": "basket",
|
||||||
|
"02946921": "can",
|
||||||
|
"03938244": "pillow",
|
||||||
|
"03710193": "mailbox",
|
||||||
|
"03207941": "dishwasher",
|
||||||
|
"04099429": "rocket",
|
||||||
|
"02773838": "bag",
|
||||||
|
"02843684": "birdhouse",
|
||||||
|
"03261776": "earphone",
|
||||||
|
"03759954": "microphone",
|
||||||
|
"04074963": "remote",
|
||||||
|
"03085013": "keyboard",
|
||||||
|
"02834778": "bicycle",
|
||||||
|
"02954340": "cap",
|
||||||
|
"02858304": "boat",
|
||||||
|
"02992529": "mobile phone"
|
||||||
|
}
|
57
pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json
Normal file
57
pytorch3d/datasets/shapenet/shapenet_synset_dict_v2.json
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
{
|
||||||
|
"02691156": "airplane",
|
||||||
|
"02747177": "trash bin",
|
||||||
|
"02773838": "bag",
|
||||||
|
"02801938": "basket",
|
||||||
|
"02808440": "bathtub",
|
||||||
|
"02818832": "bed",
|
||||||
|
"02828884": "bench",
|
||||||
|
"02843684": "birdhouse",
|
||||||
|
"02871439": "bookshelf",
|
||||||
|
"02876657": "bottle",
|
||||||
|
"02880940": "bowl",
|
||||||
|
"02924116": "bus",
|
||||||
|
"02933112": "cabinet",
|
||||||
|
"02942699": "camera",
|
||||||
|
"02946921": "can",
|
||||||
|
"02954340": "cap",
|
||||||
|
"02958343": "car",
|
||||||
|
"02992529": "cellphone",
|
||||||
|
"03001627": "chair",
|
||||||
|
"03046257": "clock",
|
||||||
|
"03085013": "keyboard",
|
||||||
|
"03207941": "dishwasher",
|
||||||
|
"03211117": "display",
|
||||||
|
"03261776": "earphone",
|
||||||
|
"03325088": "faucet",
|
||||||
|
"03337140": "file cabinet",
|
||||||
|
"03467517": "guitar",
|
||||||
|
"03513137": "helmet",
|
||||||
|
"03593526": "jar",
|
||||||
|
"03624134": "knife",
|
||||||
|
"03636649": "lamp",
|
||||||
|
"03642806": "laptop",
|
||||||
|
"03691459": "loudspeaker",
|
||||||
|
"03710193": "mailbox",
|
||||||
|
"03759954": "microphone",
|
||||||
|
"03761084": "microwaves",
|
||||||
|
"03790512": "motorbike",
|
||||||
|
"03797390": "mug",
|
||||||
|
"03928116": "piano",
|
||||||
|
"03938244": "pillow",
|
||||||
|
"03948459": "pistol",
|
||||||
|
"03991062": "flowerpot",
|
||||||
|
"04004475": "printer",
|
||||||
|
"04074963": "remote",
|
||||||
|
"04090263": "rifle",
|
||||||
|
"04099429": "rocket",
|
||||||
|
"04225987": "skateboard",
|
||||||
|
"04256520": "sofa",
|
||||||
|
"04330267": "stove",
|
||||||
|
"04379243": "table",
|
||||||
|
"04401088": "telephone",
|
||||||
|
"04460130": "tower",
|
||||||
|
"04468005": "train",
|
||||||
|
"04530566": "watercraft",
|
||||||
|
"04554684": "washer"
|
||||||
|
}
|
@ -9,7 +9,7 @@ import warnings
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
|
from pytorch3d.datasets import ShapeNetCore
|
||||||
|
|
||||||
|
|
||||||
SHAPENET_PATH = None
|
SHAPENET_PATH = None
|
||||||
@ -31,6 +31,11 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
|||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# Try load ShapeNetCore with an invalid version number and catch error.
|
||||||
|
with self.assertRaises(ValueError) as err:
|
||||||
|
ShapeNetCore(SHAPENET_PATH, version=3)
|
||||||
|
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
|
||||||
|
|
||||||
# Load ShapeNetCore without specifying any particular categories.
|
# Load ShapeNetCore without specifying any particular categories.
|
||||||
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
|
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
|
||||||
|
|
||||||
@ -51,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# Randomly retrieve an object from the dataset.
|
# Randomly retrieve an object from the dataset.
|
||||||
rand_obj = random.choice(shapenet_dataset)
|
rand_obj = random.choice(shapenet_dataset)
|
||||||
self.assertEqual(len(rand_obj), 4)
|
self.assertEqual(len(rand_obj), 5)
|
||||||
# Check that data types and shapes of items returned by __getitem__ are correct.
|
# Check that data types and shapes of items returned by __getitem__ are correct.
|
||||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||||
self.assertTrue(verts.dtype == torch.float32)
|
self.assertTrue(verts.dtype == torch.float32)
|
||||||
@ -60,3 +65,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(verts.shape[-1], 3)
|
self.assertEqual(verts.shape[-1], 3)
|
||||||
self.assertEqual(faces.ndim, 2)
|
self.assertEqual(faces.ndim, 2)
|
||||||
self.assertEqual(faces.shape[-1], 3)
|
self.assertEqual(faces.shape[-1], 3)
|
||||||
|
|
||||||
|
# Load six categories from ShapeNetCore.
|
||||||
|
# Specify categories in the form of a combination of offsets and labels.
|
||||||
|
shapenet_subset = ShapeNetCore(
|
||||||
|
SHAPENET_PATH,
|
||||||
|
synsets=[
|
||||||
|
"04330267",
|
||||||
|
"guitar",
|
||||||
|
"02801938",
|
||||||
|
"birdhouse",
|
||||||
|
"03991062",
|
||||||
|
"tower",
|
||||||
|
],
|
||||||
|
version=1,
|
||||||
|
)
|
||||||
|
subset_offsets = [
|
||||||
|
"04330267",
|
||||||
|
"03467517",
|
||||||
|
"02801938",
|
||||||
|
"02843684",
|
||||||
|
"03991062",
|
||||||
|
"04460130",
|
||||||
|
]
|
||||||
|
subset_model_nums = [
|
||||||
|
(len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1]))
|
||||||
|
for offset in subset_offsets
|
||||||
|
]
|
||||||
|
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user