Adding support for selecting categories and ver2 for ShapeNetCore

Summary: Adding support so that users can select which categories they would like to load with wordnet synset offsets or labels or a combination of both. ShapeNetCore now also supports loading v2.

Reviewed By: nikhilaravi

Differential Revision: D22039207

fbshipit-source-id: 1f0218acb790e5561e2ae373e99cebb9823eea1a
This commit is contained in:
Luya Gao
2020-06-17 20:29:23 -07:00
committed by Facebook GitHub Bot
parent 9d279ba543
commit 2ea6a7d8ad
6 changed files with 240 additions and 19 deletions

View File

@@ -9,7 +9,7 @@ import warnings
import torch
from common_testing import TestCaseMixin
from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
from pytorch3d.datasets import ShapeNetCore
SHAPENET_PATH = None
@@ -31,6 +31,11 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
warnings.warn(msg)
return True
# Try load ShapeNetCore with an invalid version number and catch error.
with self.assertRaises(ValueError) as err:
ShapeNetCore(SHAPENET_PATH, version=3)
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
# Load ShapeNetCore without specifying any particular categories.
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
@@ -51,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
# Randomly retrieve an object from the dataset.
rand_obj = random.choice(shapenet_dataset)
self.assertEqual(len(rand_obj), 4)
self.assertEqual(len(rand_obj), 5)
# Check that data types and shapes of items returned by __getitem__ are correct.
verts, faces = rand_obj["verts"], rand_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
@@ -60,3 +65,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
self.assertEqual(verts.shape[-1], 3)
self.assertEqual(faces.ndim, 2)
self.assertEqual(faces.shape[-1], 3)
# Load six categories from ShapeNetCore.
# Specify categories in the form of a combination of offsets and labels.
shapenet_subset = ShapeNetCore(
SHAPENET_PATH,
synsets=[
"04330267",
"guitar",
"02801938",
"birdhouse",
"03991062",
"tower",
],
version=1,
)
subset_offsets = [
"04330267",
"03467517",
"02801938",
"02843684",
"03991062",
"04460130",
]
subset_model_nums = [
(len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1]))
for offset in subset_offsets
]
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))