diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py
index 1305e478..56e4208d 100644
--- a/pytorch3d/datasets/r2n2/r2n2.py
+++ b/pytorch3d/datasets/r2n2/r2n2.py
@@ -64,6 +64,7 @@ class R2N2(ShapeNetBase):
                 continue
             synset_set.add(synset)
+            self.synset_starts[synset] = len(self.synset_ids)
             models = split_dict[synset].keys()
             for model in models:
                 # Examine if the given model is present in the ShapeNetCore path.
@@ -78,6 +79,7 @@ class R2N2(ShapeNetBase):
                     continue
                 self.synset_ids.append(synset)
                 self.model_ids.append(model)
+            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
 
         # Examine if all the synsets in the standard R2N2 mapping are present.
         # Update self.synset_inv so that it only includes the loaded categories.
diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py
index daf156be..d305894f 100644
--- a/pytorch3d/datasets/shapenet_base.py
+++ b/pytorch3d/datasets/shapenet_base.py
@@ -34,7 +34,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
         self.synset_starts = {}
         self.synset_lens = {}
         self.shapenet_dir = ""
-        self.model_dir = ""
+        self.model_dir = "model.obj"
 
     def __len__(self):
         """
diff --git a/tests/data/test_r2n2_render_by_categories_0.png b/tests/data/test_r2n2_render_by_categories_0.png
new file mode 100644
index 00000000..03fb791e
Binary files /dev/null and b/tests/data/test_r2n2_render_by_categories_0.png differ
diff --git a/tests/data/test_r2n2_render_by_categories_1.png b/tests/data/test_r2n2_render_by_categories_1.png
new file mode 100644
index 00000000..871e5594
Binary files /dev/null and b/tests/data/test_r2n2_render_by_categories_1.png differ
diff --git a/tests/data/test_r2n2_render_by_categories_2.png b/tests/data/test_r2n2_render_by_categories_2.png
new file mode 100644
index 00000000..461dba8e
Binary files /dev/null and b/tests/data/test_r2n2_render_by_categories_2.png differ
diff --git a/tests/data/test_r2n2_render_by_idxs_and_ids_0.png b/tests/data/test_r2n2_render_by_idxs_and_ids_0.png
new file mode 100644
index 00000000..d16c2ad9
Binary files /dev/null and b/tests/data/test_r2n2_render_by_idxs_and_ids_0.png differ
diff --git a/tests/data/test_r2n2_render_by_idxs_and_ids_1.png b/tests/data/test_r2n2_render_by_idxs_and_ids_1.png
new file mode 100644
index 00000000..ae285391
Binary files /dev/null and b/tests/data/test_r2n2_render_by_idxs_and_ids_1.png differ
diff --git a/tests/data/test_r2n2_render_by_idxs_and_ids_2.png b/tests/data/test_r2n2_render_by_idxs_and_ids_2.png
new file mode 100644
index 00000000..b953b232
Binary files /dev/null and b/tests/data/test_r2n2_render_by_idxs_and_ids_2.png differ
diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py
index 8cd2ed8c..ba5fdf69 100644
--- a/tests/test_r2n2.py
+++ b/tests/test_r2n2.py
@@ -5,10 +5,19 @@
 Sanity checks for loading R2N2.
 """
 import json
 import os
 import unittest
+from pathlib import Path
 
+import numpy as np
 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, load_rgb_image
+from PIL import Image
 from pytorch3d.datasets import R2N2, collate_batched_meshes
+from pytorch3d.renderer import (
+    OpenGLPerspectiveCameras,
+    PointLights,
+    RasterizationSettings,
+    look_at_view_transform,
+)
 from torch.utils.data import DataLoader
@@ -17,6 +26,9 @@
 R2N2_PATH = None
 SHAPENET_PATH = None
 SPLITS_PATH = None
+DEBUG = False
+DATA_DIR = Path(__file__).resolve().parent / "data"
+
 
 class TestR2N2(TestCaseMixin, unittest.TestCase):
     def setUp(self):
@@ -44,16 +56,14 @@
     def test_load_R2N2(self):
         """
-        Test loading the train split of R2N2. Check the loaded dataset return items
-        of the correct shapes and types.
+        Test that the train split of R2N2 returns items of correct shapes and types.
         """
         # Load dataset in the train split.
-        split = "train"
-        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 
         # Check total number of objects in the dataset is correct.
         with open(SPLITS_PATH) as splits:
-            split_dict = json.load(splits)[split]
+            split_dict = json.load(splits)["train"]
 
         model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
         self.assertEqual(len(r2n2_dataset), sum(model_nums))
 
@@ -75,8 +85,7 @@
         the correct shapes and types are returned.
         """
         # Load dataset in the train split.
-        split = "train"
-        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 
         # Randomly retrieve several objects from the dataset and collate them.
         collated_meshes = collate_batched_meshes(
@@ -109,3 +118,117 @@
         self.assertEqual(len(object_batch["label"]), batch_size)
         self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
         self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
+
+    def test_catch_render_arg_errors(self):
+        """
+        Test rendering R2N2 with an invalid model_id or index, and catch the
+        corresponding errors.
+        """
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Try rendering with an invalid model_id and catch the error.
+        with self.assertRaises(ValueError) as err:
+            r2n2_dataset.render(model_ids=["lamp0"])
+        self.assertTrue("not found in the loaded dataset" in str(err.exception))
+
+        # Try rendering with an index out of bounds and catch the error.
+        with self.assertRaises(IndexError) as err:
+            r2n2_dataset.render(idxs=[1000000])
+        self.assertTrue("are out of bounds" in str(err.exception))
+
+    def test_render_r2n2(self):
+        """
+        Test rendering objects from R2N2 selected both by indices and model_ids.
+        """
+        # Set up device and seed for random selections.
+        device = torch.device("cuda:0")
+        torch.manual_seed(39)
+
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Render the first three models in the dataset.
+        R, T = look_at_view_transform(1.0, 1.0, 90)
+        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
+        raster_settings = RasterizationSettings(image_size=512)
+        lights = PointLights(
+            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
+            # TODO: debug the source of the discrepancy in two images when rendering on GPU.
+            diffuse_color=((0, 0, 0),),
+            specular_color=((0, 0, 0),),
+            device=device,
+        )
+
+        r2n2_by_idxs = r2n2_dataset.render(
+            idxs=list(range(3)),
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Check that there are three images in the batch.
+        self.assertEqual(r2n2_by_idxs.shape[0], 3)
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            r2n2_by_idxs_rgb = r2n2_by_idxs[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((r2n2_by_idxs_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_r2n2_render_by_idxs_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_by_idxs_rgb, image_ref, atol=0.05)
+
+        # Render the same models but by model_ids this time.
+        r2n2_by_model_ids = r2n2_dataset.render(
+            model_ids=[
+                "1a4a8592046253ab5ff61a3a2a0e2484",
+                "1a04dcce7027357ab540cc4083acfa57",
+                "1a9d0480b74d782698f5bccb3529a48d",
+            ],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            r2n2_by_model_ids_rgb = r2n2_by_model_ids[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray(
+                    (r2n2_by_model_ids_rgb.numpy() * 255).astype(np.uint8)
+                ).save(DATA_DIR / ("DEBUG_r2n2_render_by_model_ids_%s.png" % idx))
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_by_model_ids_rgb, image_ref, atol=0.05)
+
+        ###############################
+        # Test rendering by categories
+        ###############################
+
+        # Render a mixture of categories.
+        categories = ["chair", "lamp"]
+        mixed_objs = r2n2_dataset.render(
+            categories=categories,
+            sample_nums=[1, 2],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_r2n2_render_by_categories_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(mixed_rgb, image_ref, atol=0.05)
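A note on the bookkeeping added in `r2n2.py`: because models are appended to `self.synset_ids`/`self.model_ids` one synset at a time, `synset_starts[synset]` records the offset at which a category's entries begin and `synset_lens[synset]` the size of its contiguous block — presumably what `render(categories=..., sample_nums=...)` relies on to sample within a single category. Below is a minimal sketch of that index arithmetic; `sample_idxs_by_category` is a hypothetical helper for illustration, not part of this diff:

```python
import torch

def sample_idxs_by_category(dataset, category: str, sample_num: int) -> torch.Tensor:
    # Map the human-readable category to its synset id, e.g. "lamp" -> "03636649".
    synset = dataset.synset_inv[category]
    # Invariant established in R2N2.__init__: each synset's models occupy the
    # contiguous range [start, start + length) of the dataset's global indices.
    start = dataset.synset_starts[synset]
    length = dataset.synset_lens[synset]
    # Draw sample_num random indices from that range.
    return start + torch.randint(length, (sample_num,))
```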
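Likewise, the `ShapeNetBase.model_dir` default changing from `""` to `"model.obj"` suggests the base class, rather than each subclass, composes the path to a model's mesh file. A sketch under that assumption; `model_path` is illustrative only and not an API introduced by this diff:

```python
from os import path

def model_path(dataset, idx: int) -> str:
    # Assumed ShapeNetCore layout: <shapenet_dir>/<synset_id>/<model_id>/model.obj,
    # with the filename now supplied by the model_dir attribute.
    return path.join(
        dataset.shapenet_dir,
        dataset.synset_ids[idx],
        dataset.model_ids[idx],
        dataset.model_dir,
    )
```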