diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py
index 3cf0f3f3..243247e5 100644
--- a/pytorch3d/datasets/__init__.py
+++ b/pytorch3d/datasets/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
 from .shapenet import ShapeNetCore
diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py
index e28ae797..79852975 100644
--- a/pytorch3d/datasets/shapenet/shapenet_core.py
+++ b/pytorch3d/datasets/shapenet/shapenet_core.py
@@ -41,7 +41,7 @@ class ShapeNetCore(ShapeNetBase):
         """
         super().__init__()
-        self.data_dir = data_dir
+        self.shapenet_dir = data_dir
         if version not in [1, 2]:
             raise ValueError("Version number must be either 1 or 2.")
         self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
@@ -100,6 +100,7 @@ class ShapeNetCore(ShapeNetBase):
         # Each grandchildren directory of data_dir contains an object, and the name
         # of the directory is the object's model_id.
         for synset in synset_set:
+            self.synset_starts[synset] = len(self.synset_ids)
             for model in os.listdir(path.join(data_dir, synset)):
                 if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
                     msg = (
@@ -110,6 +111,7 @@ class ShapeNetCore(ShapeNetBase):
                     continue
                 self.synset_ids.append(synset)
                 self.model_ids.append(model)
+            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
 
     def __getitem__(self, idx: int) -> Dict:
         """
@@ -128,7 +130,7 @@ class ShapeNetCore(ShapeNetBase):
         """
         model = self._get_item_ids(idx)
         model_path = path.join(
-            self.data_dir, model["synset_id"], model["model_id"], self.model_dir
+            self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir
         )
         model["verts"], faces, _ = load_obj(model_path)
         model["faces"] = faces.verts_idx
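The two dicts maintained above, synset_starts and synset_lens, record where each synset's models begin in the flat synset_ids/model_ids lists and how many models the synset contributes. Because models are appended synset by synset, every category ends up owning one contiguous index range. A minimal standalone sketch of that invariant (the synset offsets and counts below are illustrative values, not taken from the patch):

    # Illustrative sketch mirroring the per-synset bookkeeping in ShapeNetCore.__init__.
    # Models are appended grouped by synset, so each synset is a contiguous slice.
    synset_ids = ["04330267", "04330267", "03928116", "03928116", "03928116"]

    synset_starts, synset_lens = {}, {}
    for i, synset in enumerate(synset_ids):
        if synset not in synset_starts:
            synset_starts[synset] = i  # first dataset index for this synset
        synset_lens[synset] = i - synset_starts[synset] + 1  # models seen so far

    # Synset "03928116" occupies the contiguous index range [2, 5).
    start = synset_starts["03928116"]
    assert list(range(start, start + synset_lens["03928116"])) == [2, 3, 4]

This contiguous-range invariant is what lets the sampling helper in shapenet_base.py draw per-category samples with a single offset and length.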
diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py
index f76546ce..daf156be 100644
--- a/pytorch3d/datasets/shapenet_base.py
+++ b/pytorch3d/datasets/shapenet_base.py
@@ -1,8 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import Dict
+import warnings
+from os import path
+from typing import Dict, List, Optional
 
 import torch
+from pytorch3d.io import load_objs_as_meshes
 from pytorch3d.renderer import (
     HardPhongShader,
     MeshRasterizer,
     MeshRenderer,
@@ -11,7 +14,7 @@ from pytorch3d.renderer import (
     PointLights,
     RasterizationSettings,
 )
-from pytorch3d.structures import Meshes, Textures
+from pytorch3d.structures import Textures
 
 
 class ShapeNetBase(torch.utils.data.Dataset):
@@ -27,6 +30,11 @@ class ShapeNetBase(torch.utils.data.Dataset):
         """
         self.synset_ids = []
         self.model_ids = []
+        self.synset_inv = {}
+        self.synset_starts = {}
+        self.synset_lens = {}
+        self.shapenet_dir = ""
+        self.model_dir = ""
 
     def __len__(self):
         """
@@ -67,30 +75,46 @@ class ShapeNetBase(torch.utils.data.Dataset):
         return model
 
     def render(
-        self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
+        self,
+        model_ids: Optional[List[str]] = None,
+        categories: Optional[List[str]] = None,
+        sample_nums: Optional[List[int]] = None,
+        idxs: Optional[List[int]] = None,
+        shader_type=HardPhongShader,
+        device="cpu",
+        **kwargs
     ) -> torch.Tensor:
         """
-        Renders a model by the given index.
+        If a list of model_ids is supplied, render all the objects by the given model_ids.
+        If no model_ids are supplied but categories are specified, randomly select a
+        number of objects (specified in sample_nums, default one per category) from the
+        given categories and render these objects. If instead a list of idxs is specified,
+        check that the idxs are all valid and render models by the given idxs. Otherwise,
+        randomly select a number (first number in sample_nums, default 1) of models from
+        the loaded dataset and render these models.
 
         Args:
-            idx: The index of model to be rendered in the dataset.
-            shader_type: select shading. Valid options include HardPhongShader (default),
+            model_ids: List[str] of model_ids of models intended to be rendered.
+            categories: List[str] of categories intended to be rendered. categories
+                can be given in the form of synset offsets or labels, or a combination
+                of both.
+            sample_nums: List[int] of numbers of models to be randomly sampled from
+                each category. May also contain a single integer, in which case it
+                is broadcast to every category.
+            idxs: List[int] of indices of models to be rendered in the dataset.
+            shader_type: Select shading. Valid options include HardPhongShader (default),
                 SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
                 SoftSilhouetteShader.
             device: torch.device on which the tensors should be located.
             **kwargs: Accepts any of the kwargs that the renderer supports.
 
         Returns:
-            Rendered image of shape (1, H, W, 3).
+            Batch of rendered images of shape (N, H, W, 3).
         """
-
-        model = self.__getitem__(idx)
-        verts, faces = model["verts"], model["faces"]
-        verts_rgb = torch.ones_like(verts, device=device)[None]
-        mesh = Meshes(
-            verts=[verts.to(device)],
-            faces=[faces.to(device)],
-            textures=Textures(verts_rgb=verts_rgb.to(device)),
+        paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
+        meshes.textures = Textures(
+            verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
         )
         cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
         renderer = MeshRenderer(
@@ -104,4 +128,125 @@ class ShapeNetBase(torch.utils.data.Dataset):
                 lights=kwargs.get("lights", PointLights()).to(device),
             ),
         )
-        return renderer(mesh)
+        return renderer(meshes)
+
+    def _handle_render_inputs(
+        self,
+        model_ids: Optional[List[str]] = None,
+        categories: Optional[List[str]] = None,
+        sample_nums: Optional[List[int]] = None,
+        idxs: Optional[List[int]] = None,
+    ) -> List[str]:
+        """
+        Helper function for converting user-provided model_ids, categories and
+        sample_nums to indices of models in the loaded dataset. If model idxs are
+        provided, check that the idxs are valid. If no models are specified, a
+        single model is randomly sampled from the loaded dataset. The function
+        returns the file paths to the selected models.
+
+        Args:
+            model_ids: List[str] of model_ids of models to be rendered.
+            categories: List[str] of categories to be rendered.
+            sample_nums: List[int] of numbers of models to be randomly sampled from
+                each category.
+            idxs: List[int] of indices of models to be rendered in the dataset.
+
+        Returns:
+            List of paths of models to be rendered.
+        """
+        # Get corresponding indices if model_ids are supplied.
+        if model_ids is not None and len(model_ids) > 0:
+            idxs = []
+            for model_id in model_ids:
+                if model_id not in self.model_ids:
+                    raise ValueError(
+                        "model_id %s not found in the loaded dataset."
+                        % model_id
+                    )
+                idxs.append(self.model_ids.index(model_id))
+        # Sample random models if categories and sample_nums are supplied and get
+        # the corresponding indices.
+        elif categories is not None and len(categories) > 0:
+            sample_nums = [1] if sample_nums is None else sample_nums
+            if len(categories) != len(sample_nums) and len(sample_nums) != 1:
+                raise ValueError(
+                    "categories and sample_nums need to be of the same length or "
+                    "sample_nums needs to be of length 1."
+                )
+            idxs_tensor = torch.empty(0, dtype=torch.int32)
+            for i in range(len(categories)):
+                category = self.synset_inv.get(categories[i], categories[i])
+                if category not in self.synset_inv.values():
+                    raise ValueError(
+                        "Category %s is not in the loaded dataset." % category
+                    )
+                # Broadcast if sample_nums has length of 1.
+                sample_num = sample_nums[i] if len(sample_nums) > 1 else sample_nums[0]
+                sampled_idxs = self._sample_idxs_from_category(
+                    sample_num=sample_num, category=category
+                )
+                idxs_tensor = torch.cat((idxs_tensor, sampled_idxs))
+            idxs = idxs_tensor.tolist()
+        # Check that the indices are valid if idxs are supplied.
+        elif idxs is not None and len(idxs) > 0:
+            if any(idx < 0 or idx >= len(self.model_ids) for idx in idxs):
+                raise IndexError(
+                    "One or more idx values are out of bounds. Indices need to be "
+                    "between 0 and %s." % (len(self.model_ids) - 1)
+                )
+        # If sample_nums is specified, sample sample_nums[0] models from the
+        # entire loaded dataset. Otherwise randomly select one model from the
+        # dataset.
+        else:
+            sample_nums = [1] if sample_nums is None else sample_nums
+            if len(sample_nums) > 1:
+                msg = (
+                    "More than one sample size specified, now sampling "
+                    "%d models from the dataset." % sample_nums[0]
+                )
+                warnings.warn(msg)
+            idxs = self._sample_idxs_from_category(sample_nums[0])
+        return [
+            path.join(
+                self.shapenet_dir,
+                self.synset_ids[idx],
+                self.model_ids[idx],
+                self.model_dir,
+            )
+            for idx in idxs
+        ]
+
+    def _sample_idxs_from_category(
+        self, sample_num: int = 1, category: Optional[str] = None
+    ) -> List[int]:
+        """
+        Helper function for sampling a number of indices from the given category.
+
+        Args:
+            sample_num: number of indices to be sampled from the given category.
+            category: category synset of the category to be sampled from. If not
+                specified, sample from all models in the loaded dataset.
+        """
+        start = self.synset_starts[category] if category is not None else 0
+        range_len = (
+            self.synset_lens[category] if category is not None else self.__len__()
+        )
+        replacement = sample_num > range_len
+        sampled_idxs = (
+            torch.multinomial(
+                torch.ones((range_len), dtype=torch.float32),
+                sample_num,
+                replacement=replacement,
+            )
+            + start
+        )
+        if replacement:
+            msg = (
+                "Sample size %d is larger than the number of objects in %s, "
+                "values sampled with replacement."
+            ) % (
+                sample_num,
+                "category " + category if category is not None else "all categories",
+            )
+            warnings.warn(msg)
+        return sampled_idxs
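For reference, the sampling step inside _sample_idxs_from_category can be exercised on its own: torch.multinomial draws sample_num indices uniformly from a category's contiguous range [start, start + range_len), falling back to sampling with replacement only when more models are requested than exist. A standalone sketch, with made-up values standing in for a category's slice of the dataset:

    import torch

    # Made-up offsets: a category starting at dataset index 120 with 50 models.
    start, range_len, sample_num = 120, 50, 3

    replacement = sample_num > range_len  # oversample only with replacement
    sampled_idxs = (
        torch.multinomial(
            torch.ones(range_len, dtype=torch.float32),  # uniform weights
            sample_num,
            replacement=replacement,
        )
        + start  # shift from [0, range_len) into the dataset's index space
    )
    print(sampled_idxs)  # e.g. tensor([135, 168, 149])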
diff --git a/tests/data/test_shapenet_core_render_mixed_by_categories_0.png b/tests/data/test_shapenet_core_render_mixed_by_categories_0.png
new file mode 100644
index 00000000..5104bd71
Binary files /dev/null and b/tests/data/test_shapenet_core_render_mixed_by_categories_0.png differ
diff --git a/tests/data/test_shapenet_core_render_mixed_by_categories_1.png b/tests/data/test_shapenet_core_render_mixed_by_categories_1.png
new file mode 100644
index 00000000..4df67122
Binary files /dev/null and b/tests/data/test_shapenet_core_render_mixed_by_categories_1.png differ
diff --git a/tests/data/test_shapenet_core_render_mixed_by_categories_2.png b/tests/data/test_shapenet_core_render_mixed_by_categories_2.png
new file mode 100644
index 00000000..e1c5ebb0
Binary files /dev/null and b/tests/data/test_shapenet_core_render_mixed_by_categories_2.png differ
diff --git a/tests/data/test_shapenet_core_render_piano_0.png b/tests/data/test_shapenet_core_render_piano_0.png
new file mode 100644
index 00000000..fc7524c8
Binary files /dev/null and b/tests/data/test_shapenet_core_render_piano_0.png differ
diff --git a/tests/data/test_shapenet_core_render_piano_1.png b/tests/data/test_shapenet_core_render_piano_1.png
new file mode 100644
index 00000000..b53a022a
Binary files /dev/null and b/tests/data/test_shapenet_core_render_piano_1.png differ
diff --git a/tests/data/test_shapenet_core_render_piano_2.png b/tests/data/test_shapenet_core_render_piano_2.png
new file mode 100644
index 00000000..fdcd4933
Binary files /dev/null and b/tests/data/test_shapenet_core_render_piano_2.png differ
diff --git a/tests/data/test_shapenet_core_render_without_sample_nums_0.png b/tests/data/test_shapenet_core_render_without_sample_nums_0.png
new file mode 100644
index 00000000..3fe33135
Binary files /dev/null and b/tests/data/test_shapenet_core_render_without_sample_nums_0.png differ
diff --git a/tests/data/test_shapenet_core_render_without_sample_nums_1.png b/tests/data/test_shapenet_core_render_without_sample_nums_1.png
new file mode 100644
index 00000000..cbf4d692
Binary files /dev/null and b/tests/data/test_shapenet_core_render_without_sample_nums_1.png differ
diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py
index db92c83a..a1fc664c 100644
--- a/tests/test_shapenet_core.py
+++ b/tests/test_shapenet_core.py
@@ -1,11 +1,9 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 """
-Sanity checks for loading ShapeNet Core v1.
+Sanity checks for loading ShapeNetCore.
 """
 import os
-import random
 import unittest
-import warnings
 from pathlib import Path
 
 import numpy as np
@@ -21,6 +19,7 @@ from pytorch3d.renderer import (
 )
 
 
+# Set SHAPENET_PATH to the local path of the dataset
 SHAPENET_PATH = None
 # If DEBUG=True, save out images generated in the tests for debugging.
 # All saved images have prefix DEBUG_
@@ -29,23 +28,26 @@ DATA_DIR = Path(__file__).resolve().parent / "data"
 
 
 class TestShapenetCore(TestCaseMixin, unittest.TestCase):
-    def test_load_shapenet_core(self):
-        # Setup
-        device = torch.device("cuda:0")
-
-        # The ShapeNet dataset is not provided in the repo.
-        # Download this separately and update the `shapenet_path`
-        # with the location of the dataset in order to run this test.
+    def setUp(self):
+        """
+        Check whether the ShapeNet dataset is available at SHAPENET_PATH.
+        If not, download it separately and update SHAPENET_PATH
+        with the location of the dataset in order to run the tests.
+        """
         if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
             url = "https://www.shapenet.org/"
-            msg = """ShapeNet data not found, download from %s, save it at the path %s,
-                update SHAPENET_PATH at the top of the file, and rerun""" % (
-                url,
-                SHAPENET_PATH,
+            msg = (
+                "ShapeNet data not found, download from %s, update "
+                "SHAPENET_PATH at the top of the file, and rerun."
             )
-            warnings.warn(msg)
-            return True
+            self.skipTest(msg % url)
+
+    def test_load_shapenet_core(self):
+        """
+        Test loading both the entire ShapeNetCore dataset and a subset of the
+        ShapeNetCore dataset. Check that the loaded datasets return items of the
+        correct shapes and types.
+        """
         # Try loading ShapeNetCore with an invalid version number and catch error.
         with self.assertRaises(ValueError) as err:
             ShapeNetCore(SHAPENET_PATH, version=3)
@@ -70,8 +72,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
         self.assertEqual(len(shapenet_dataset), sum(model_num_list))
 
         # Randomly retrieve an object from the dataset.
-        rand_obj = random.choice(shapenet_dataset)
-        self.assertEqual(len(rand_obj), 5)
+        rand_obj = shapenet_dataset[torch.randint(len(shapenet_dataset), (1,))]
         # Check that data types and shapes of items returned by __getitem__ are correct.
         verts, faces = rand_obj["verts"], rand_obj["faces"]
         self.assertTrue(verts.dtype == torch.float32)
@@ -82,7 +83,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
         self.assertEqual(faces.shape[-1], 3)
 
         # Load six categories from ShapeNetCore.
-        # Specify categories in the form of a combination of offsets and labels.
+        # Specify categories with a combination of offsets and labels.
         shapenet_subset = ShapeNetCore(
             SHAPENET_PATH,
             synsets=[
@@ -109,10 +110,37 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
         ]
         self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
 
-        # Render the first image in the piano category.
-        R, T = look_at_view_transform(1.0, 1.0, 90)
+    def test_catch_render_arg_errors(self):
+        """
+        Test rendering ShapeNetCore with invalid model_ids, categories or indices,
+        and catch the corresponding errors.
+        """
+        # Load ShapeNetCore.
+        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
+
+        # Try rendering with an invalid model_id and catch error.
+        with self.assertRaises(ValueError) as err:
+            shapenet_dataset.render(model_ids=["piano0"])
+        self.assertTrue("not found in the loaded dataset" in str(err.exception))
+
+        # Try rendering with an index out of bounds and catch error.
+        with self.assertRaises(IndexError) as err:
+            shapenet_dataset.render(idxs=[100000])
+        self.assertTrue("are out of bounds" in str(err.exception))
+
+    def test_render_shapenet_core(self):
+        """
+        Test rendering objects from ShapeNetCore.
+        """
+        # Set up the device and seed the random selections.
+        device = torch.device("cuda:0")
+        torch.manual_seed(39)
+
+        # Load category piano from ShapeNetCore.
         piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])
+
+        # Rendering settings.
+        R, T = look_at_view_transform(1.0, 1.0, 90)
         cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
         raster_settings = RasterizationSettings(image_size=512)
         lights = PointLights(
@@ -122,17 +150,102 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
             specular_color=((0, 0, 0),),
             device=device,
         )
-        images = piano_dataset.render(
-            0,
+
+        # Render the first three models in the piano category.
+        pianos = piano_dataset.render(
+            idxs=list(range(3)),
             device=device,
             cameras=cameras,
             raster_settings=raster_settings,
             lights=lights,
         )
-        rgb = images[0, ..., :3].squeeze().cpu()
-        if DEBUG:
-            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
-                DATA_DIR / "DEBUG_shapenet_core_render_piano.png"
+        # Check that there are three images in the batch.
+        self.assertEqual(pianos.shape[0], 3)
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            piano_rgb = pianos[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((piano_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_shapenet_core_render_piano_by_idxs_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_shapenet_core_render_piano_%s.png" % idx, DATA_DIR
             )
-        image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR)
-        self.assertClose(rgb, image_ref, atol=0.05)
+            self.assertClose(piano_rgb, image_ref, atol=0.05)
+
+        # Render the same piano models but by model_ids this time.
+        pianos_2 = piano_dataset.render(
+            model_ids=[
+                "13394ca47c89f91525a3aaf903a41c90",
+                "14755c2ee8e693aba508f621166382b0",
+                "156c4207af6d2c8f1fdc97905708b8ea",
+            ],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            piano_rgb_2 = pianos_2[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((piano_rgb_2.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_shapenet_core_render_piano_by_ids_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_shapenet_core_render_piano_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(piano_rgb_2, image_ref, atol=0.05)
+
+        #######################
+        # Render by categories
+        #######################
+
+        # Load ShapeNetCore.
+        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
+
+        # Render a mixture of categories and specify the number of models to be
+        # randomly sampled from each category.
+        mixed_objs = shapenet_dataset.render(
+            categories=["faucet", "chair"],
+            sample_nums=[2, 1],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR
+                    / ("DEBUG_shapenet_core_render_mixed_by_categories_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_shapenet_core_render_mixed_by_categories_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(mixed_rgb, image_ref, atol=0.05)
+
+        # Render a mixture of categories without specifying sample_nums.
+        mixed_objs_2 = shapenet_dataset.render(
+            categories=["faucet", "chair"],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Compare the rendered models to the reference images.
+        for idx in range(2):
+            mixed_rgb_2 = mixed_objs_2[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((mixed_rgb_2.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR
+                    / ("DEBUG_shapenet_core_render_without_sample_nums_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_shapenet_core_render_without_sample_nums_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(mixed_rgb_2, image_ref, atol=0.05)
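Taken together, the extended render() API can be driven as in the following hypothetical snippet. It assumes a local ShapeNetCore download; the path and model ids are placeholders, and the renderer kwargs (cameras, lights, raster settings) are optional, defaulting as in the tests above:

    import torch
    from pytorch3d.datasets import ShapeNetCore

    # Placeholder path; point this at a local ShapeNetCore download.
    shapenet_dataset = ShapeNetCore("/path/to/ShapeNetCore.v1")
    device = torch.device("cuda:0")

    # Render the first two models in the loaded dataset.
    images = shapenet_dataset.render(idxs=[0, 1], device=device)

    # Render two random faucets and one random chair; a single-integer
    # sample_nums would be broadcast to every category.
    images = shapenet_dataset.render(
        categories=["faucet", "chair"], sample_nums=[2, 1], device=device
    )

    # Render specific models by id (placeholder ids shown).
    images = shapenet_dataset.render(
        model_ids=["<model_id_0>", "<model_id_1>"], device=device
    )

The four selection mechanisms (model_ids, categories plus sample_nums, idxs, or nothing at all) are mutually exclusive and are resolved in that order by _handle_render_inputs.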