Skeleton of ShapeNetCore class

Summary: Skeleton of a ShapeNetCore class that loads ShapeNet v1 from a given directory into a Dataset object. It overrides __init__, __len__ and __getitem__ from torch.utils.data.Dataset. Currently __getitem__ returns verts, faces, synset_id and model_id. Planning to add support for loading ShapeNet v2, retrieving textures and returning WordNet synsets (not just ids) in the next diffs.
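
A minimal usage sketch (the dataset path below is a placeholder; ShapeNet Core v1 must be downloaded separately):

from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore

SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # placeholder: local copy of ShapeNet Core v1
dataset = ShapeNetCore(SHAPENET_PATH)

print(len(dataset))  # total number of models found under SHAPENET_PATH
sample = dataset[0]
# sample is a dict: verts (V, 3) float tensor, faces (F, 3) long tensor, synset_id, model_id
print(sample["synset_id"], sample["model_id"], sample["verts"].shape, sample["faces"].shape)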

Reviewed By: nikhilaravi

Differential Revision: D21986222

fbshipit-source-id: c2c515303f1898b6c495b52cb53c74d691585326
Luya Gao 2020-06-17 20:29:23 -07:00 committed by Facebook GitHub Bot
parent 2f6387f239
commit 9d279ba543
2 changed files with 136 additions and 0 deletions

@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import os
import warnings
from os import path

import torch

from pytorch3d.io import load_obj


class ShapeNetCore(torch.utils.data.Dataset):
"""
This class loads ShapeNet v.1 from a given directory into a Dataset object.
"""
def __init__(self, data_dir):
"""
Stores each object's synset id and models id from data_dir.
Args:
data_dir (path): Path to shapenet data
"""
self.data_dir = data_dir
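        # Expected layout of ShapeNet Core v1 (one object per model.obj):
        #     data_dir/<synset_id>/<model_id>/model.obj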
# List of subdirectories of data_dir each containing a category of models.
# The name of each subdirectory is the wordnet synset offset of that category.
wnsynset_list = [
wnsynset
for wnsynset in os.listdir(data_dir)
if path.isdir(path.join(data_dir, wnsynset))
]
        # Extract the synset_id and model_id of each object from the directory names.
        # Each grandchild directory of data_dir contains an object, and the name
        # of that directory is the object's model_id.
self.synset_ids = []
self.model_ids = []
for synset in wnsynset_list:
for model in os.listdir(path.join(data_dir, synset)):
                if not path.exists(path.join(data_dir, synset, model, "model.obj")):
                    msg = (
                        "model.obj not found in the model directory %s "
                        "under synset directory %s." % (model, synset)
                    )
                    warnings.warn(msg)
                else:
                    self.synset_ids.append(synset)
                    self.model_ids.append(model)

    def __len__(self):
        """
        Return the total number of models in the loaded ShapeNet Core dataset.
        """
        return len(self.model_ids)

    def __getitem__(self, idx):
        """
        Read a model by the given index.

        Returns:
            A dictionary with the following keys:
                - verts: FloatTensor of shape (V, 3).
                - faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
                - synset_id (str): synset id of the model.
                - model_id (str): model id of the model.
        """
model = {}
model["synset_id"] = self.synset_ids[idx]
model["model_id"] = self.model_ids[idx]
model_path = path.join(
self.data_dir, model["synset_id"], model["model_id"], "model.obj"
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx
return model
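
Since verts and faces vary in size from model to model, the default DataLoader collate cannot stack them into batch tensors. A minimal batching sketch (not part of this diff; the dataset path is a placeholder) keeps each batch as a plain list of sample dicts:

from torch.utils.data import DataLoader

from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore


def collate_as_list(batch):
    # verts/faces differ in shape across models, so return the samples
    # as a list of dicts instead of trying to stack them into tensors.
    return batch


dataset = ShapeNetCore("/path/to/ShapeNetCore.v1")  # placeholder path
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_as_list)
batch = next(iter(loader))  # a list of 4 sample dicts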

@@ -0,0 +1,62 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

"""
Sanity checks for loading ShapeNet Core v1.
"""
import os
import random
import unittest
import warnings

import torch
from common_testing import TestCaseMixin

from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore

SHAPENET_PATH = None


class TestShapenetCore(TestCaseMixin, unittest.TestCase):
    def test_load_shapenet_core(self):
        # The ShapeNet dataset is not provided in the repo.
        # Download it separately and update `SHAPENET_PATH` at the top of the file
        # with the location of the dataset in order to run this test.
        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
            url = "https://www.shapenet.org/"
            msg = (
                "ShapeNet data not found: download it from %s, save it locally, "
                "update SHAPENET_PATH at the top of this file, and rerun." % url
            )
            warnings.warn(msg)
            return
# Load ShapeNetCore without specifying any particular categories.
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
        # Count the number of grandchild directories (which should equal the
        # total number of objects in the dataset) by walking through the given
        # directory.
wnsynset_list = [
wnsynset
for wnsynset in os.listdir(SHAPENET_PATH)
if os.path.isdir(os.path.join(SHAPENET_PATH, wnsynset))
]
model_num_list = [
(len(next(os.walk(os.path.join(SHAPENET_PATH, wnsynset)))[1]))
for wnsynset in wnsynset_list
]
        # Check that the total number of objects in the dataset is correct.
self.assertEqual(len(shapenet_dataset), sum(model_num_list))
# Randomly retrieve an object from the dataset.
rand_obj = random.choice(shapenet_dataset)
self.assertEqual(len(rand_obj), 4)
# Check that data types and shapes of items returned by __getitem__ are correct.
verts, faces = rand_obj["verts"], rand_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
self.assertTrue(faces.dtype == torch.int64)
self.assertEqual(verts.ndim, 2)
self.assertEqual(verts.shape[-1], 3)
self.assertEqual(faces.ndim, 2)
self.assertEqual(faces.shape[-1], 3)
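
Beyond these sanity checks, a short sketch (not part of this diff; the dataset path is a placeholder) of how a returned sample could be wrapped into a pytorch3d Meshes structure for downstream ops:

from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
from pytorch3d.structures import Meshes

dataset = ShapeNetCore("/path/to/ShapeNetCore.v1")  # placeholder path
sample = dataset[0]
# Build a batch of one mesh from the verts/faces tensors returned by __getitem__.
mesh = Meshes(verts=[sample["verts"]], faces=[sample["faces"]])
print(mesh.num_verts_per_mesh(), mesh.num_faces_per_mesh())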