mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Skeleton of ShapeNetCore class
Summary: Skeleton of ShapeNetCore class that loads ShapeNet v1 from a given directory to a Dataset object. Overrides _init_, _len_, and _getitem_ from torch.utils.data.Dataset. Currently getitem returns verts, faces and id_str, where id_str is a concatenation of synset_id and obj_id. Planning on adding support for loading ShapeNet v2, retrieving textures and returning wordnet synsets (not just ids) in next diffs. Reviewed By: nikhilaravi Differential Revision: D21986222 fbshipit-source-id: c2c515303f1898b6c495b52cb53c74d691585326
This commit is contained in:
parent
2f6387f239
commit
9d279ba543
74
pytorch3d/datasets/shapenet/shapenet_core.py
Normal file
74
pytorch3d/datasets/shapenet/shapenet_core.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from pytorch3d.io import load_obj
|
||||
|
||||
|
||||
class ShapeNetCore(torch.utils.data.Dataset):
|
||||
"""
|
||||
This class loads ShapeNet v.1 from a given directory into a Dataset object.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir):
|
||||
"""
|
||||
Stores each object's synset id and models id from data_dir.
|
||||
Args:
|
||||
data_dir (path): Path to shapenet data
|
||||
"""
|
||||
self.data_dir = data_dir
|
||||
|
||||
# List of subdirectories of data_dir each containing a category of models.
|
||||
# The name of each subdirectory is the wordnet synset offset of that category.
|
||||
wnsynset_list = [
|
||||
wnsynset
|
||||
for wnsynset in os.listdir(data_dir)
|
||||
if path.isdir(path.join(data_dir, wnsynset))
|
||||
]
|
||||
|
||||
# Extract synset_id and model_id of each object from directory names.
|
||||
# Each grandchildren directory of data_dir contains an object, and the name
|
||||
# of the directory is the object's model_id.
|
||||
self.synset_ids = []
|
||||
self.model_ids = []
|
||||
for synset in wnsynset_list:
|
||||
for model in os.listdir(path.join(data_dir, synset)):
|
||||
if not path.exists(path.join(data_dir, synset, model, "model.obj")):
|
||||
msg = """ model.obj not found in the model directory %s
|
||||
under synset directory %s.""" % (
|
||||
model,
|
||||
synset,
|
||||
)
|
||||
warnings.warn(msg)
|
||||
else:
|
||||
self.synset_ids.append(synset)
|
||||
self.model_ids.append(model)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns # of total models in shapenet core
|
||||
"""
|
||||
return len(self.model_ids)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Read a model by the given index.
|
||||
Returns:
|
||||
dictionary with following keys:
|
||||
- verts: FloatTensor of shape (V, 3).
|
||||
- faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
|
||||
- synset_id (str): synset id
|
||||
- model_id (str): model id
|
||||
"""
|
||||
model = {}
|
||||
model["synset_id"] = self.synset_ids[idx]
|
||||
model["model_id"] = self.model_ids[idx]
|
||||
model_path = path.join(
|
||||
self.data_dir, model["synset_id"], model["model_id"], "model.obj"
|
||||
)
|
||||
model["verts"], faces, _ = load_obj(model_path)
|
||||
model["faces"] = faces.verts_idx
|
||||
return model
|
62
tests/test_shapenet_core.py
Normal file
62
tests/test_shapenet_core.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
Sanity checks for loading ShapeNet Core v1.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
|
||||
|
||||
|
||||
SHAPENET_PATH = None
|
||||
|
||||
|
||||
class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
def test_load_shapenet_core(self):
|
||||
|
||||
# The ShapeNet dataset is not provided in the repo.
|
||||
# Download this separately and update the `shapenet_path`
|
||||
# with the location of the dataset in order to run this test.
|
||||
if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
|
||||
url = "https://www.shapenet.org/"
|
||||
msg = """ShapeNet data not found, download from %s, save it at the path %s,
|
||||
update SHAPENET_PATH at the top of the file, and rerun""" % (
|
||||
url,
|
||||
SHAPENET_PATH,
|
||||
)
|
||||
warnings.warn(msg)
|
||||
return True
|
||||
|
||||
# Load ShapeNetCore without specifying any particular categories.
|
||||
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
|
||||
|
||||
# Count the number of grandchildren directories (which should be equal to
|
||||
# the total number of objects in the dataset) by walking through the given
|
||||
# directory.
|
||||
wnsynset_list = [
|
||||
wnsynset
|
||||
for wnsynset in os.listdir(SHAPENET_PATH)
|
||||
if os.path.isdir(os.path.join(SHAPENET_PATH, wnsynset))
|
||||
]
|
||||
model_num_list = [
|
||||
(len(next(os.walk(os.path.join(SHAPENET_PATH, wnsynset)))[1]))
|
||||
for wnsynset in wnsynset_list
|
||||
]
|
||||
# Check total number of objects in the dataset is correct.
|
||||
self.assertEqual(len(shapenet_dataset), sum(model_num_list))
|
||||
|
||||
# Randomly retrieve an object from the dataset.
|
||||
rand_obj = random.choice(shapenet_dataset)
|
||||
self.assertEqual(len(rand_obj), 4)
|
||||
# Check that data types and shapes of items returned by __getitem__ are correct.
|
||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||
self.assertTrue(verts.dtype == torch.float32)
|
||||
self.assertTrue(faces.dtype == torch.int64)
|
||||
self.assertEqual(verts.ndim, 2)
|
||||
self.assertEqual(verts.shape[-1], 3)
|
||||
self.assertEqual(faces.ndim, 2)
|
||||
self.assertEqual(faces.shape[-1], 3)
|
Loading…
x
Reference in New Issue
Block a user