mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Adding support for selecting categories and ver2 for ShapeNetCore
Summary: Adding support so that users can select which categories they would like to load with wordnet synset offsets or labels or a combination of both. ShapeNetCore now also supports loading v2. Reviewed By: nikhilaravi Differential Revision: D22039207 fbshipit-source-id: 1f0218acb790e5561e2ae373e99cebb9823eea1a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9d279ba543
commit
2ea6a7d8ad
@@ -9,7 +9,7 @@ import warnings
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.datasets.shapenet.shapenet_core import ShapeNetCore
|
||||
from pytorch3d.datasets import ShapeNetCore
|
||||
|
||||
|
||||
SHAPENET_PATH = None
|
||||
@@ -31,6 +31,11 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
warnings.warn(msg)
|
||||
return True
|
||||
|
||||
# Try load ShapeNetCore with an invalid version number and catch error.
|
||||
with self.assertRaises(ValueError) as err:
|
||||
ShapeNetCore(SHAPENET_PATH, version=3)
|
||||
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
|
||||
|
||||
# Load ShapeNetCore without specifying any particular categories.
|
||||
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
|
||||
|
||||
@@ -51,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# Randomly retrieve an object from the dataset.
|
||||
rand_obj = random.choice(shapenet_dataset)
|
||||
self.assertEqual(len(rand_obj), 4)
|
||||
self.assertEqual(len(rand_obj), 5)
|
||||
# Check that data types and shapes of items returned by __getitem__ are correct.
|
||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||
self.assertTrue(verts.dtype == torch.float32)
|
||||
@@ -60,3 +65,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(verts.shape[-1], 3)
|
||||
self.assertEqual(faces.ndim, 2)
|
||||
self.assertEqual(faces.shape[-1], 3)
|
||||
|
||||
# Load six categories from ShapeNetCore.
|
||||
# Specify categories in the form of a combination of offsets and labels.
|
||||
shapenet_subset = ShapeNetCore(
|
||||
SHAPENET_PATH,
|
||||
synsets=[
|
||||
"04330267",
|
||||
"guitar",
|
||||
"02801938",
|
||||
"birdhouse",
|
||||
"03991062",
|
||||
"tower",
|
||||
],
|
||||
version=1,
|
||||
)
|
||||
subset_offsets = [
|
||||
"04330267",
|
||||
"03467517",
|
||||
"02801938",
|
||||
"02843684",
|
||||
"03991062",
|
||||
"04460130",
|
||||
]
|
||||
subset_model_nums = [
|
||||
(len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1]))
|
||||
for offset in subset_offsets
|
||||
]
|
||||
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
|
||||
|
||||
Reference in New Issue
Block a user