From dc08c30583085bfbd818f0bc07124c34b095c5c4 Mon Sep 17 00:00:00 2001 From: Luya Gao Date: Tue, 14 Jul 2020 14:52:21 -0700 Subject: [PATCH] Return R2N2 renderings Summary: R2N2 returns R2N2's own renderings of ShapeNetCore models. Reviewed By: nikhilaravi Differential Revision: D22266988 fbshipit-source-id: 36e67bd06c6459773e6e5f654259166b579be36a --- pytorch3d/datasets/r2n2/r2n2.py | 76 ++++++++++++++++++++++++++++++--- pytorch3d/datasets/utils.py | 7 +++ tests/test_r2n2.py | 16 +++++-- 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py index 56e4208d..553736a5 100644 --- a/pytorch3d/datasets/r2n2/r2n2.py +++ b/pytorch3d/datasets/r2n2/r2n2.py @@ -4,8 +4,11 @@ import json import warnings from os import path from pathlib import Path -from typing import Dict +from typing import Dict, List, Optional +import numpy as np +import torch +from PIL import Image from pytorch3d.datasets.shapenet_base import ShapeNetBase from pytorch3d.io import load_obj @@ -21,18 +24,29 @@ class R2N2(ShapeNetBase): voxelized models. """ - def __init__(self, split, shapenet_dir, r2n2_dir, splits_file): + def __init__( + self, + split: str, + shapenet_dir, + r2n2_dir, + splits_file, + return_all_views: bool = True, + ): """ Store each object's synset id and models id the given directories. + Args: split (str): One of (train, val, test). shapenet_dir (path): Path to ShapeNet core v1. r2n2_dir (path): Path to the R2N2 dataset. splits_file (path): File containing the train/val/test splits. + return_all_views (bool): Indicator of whether or not to return all 24 views. If set + to False, one of the 24 views would be randomly selected and returned. """ super().__init__() self.shapenet_dir = shapenet_dir self.r2n2_dir = r2n2_dir + self.return_all_views = return_all_views # Examine if split is valid. if split not in ["train", "val", "test"]: raise ValueError("split has to be one of (train, val, test).") @@ -48,6 +62,16 @@ class R2N2(ShapeNetBase): with open(splits_file) as splits: split_dict = json.load(splits)[split] + self.return_images = True + # Check if the folder containing R2N2 renderings is included in r2n2_dir. + if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")): + self.return_images = False + msg = ( + "ShapeNetRendering not found in %s. R2N2 renderings will " + "be skipped when returning models." + ) % (r2n2_dir) + warnings.warn(msg) + synset_set = set() for synset in split_dict.keys(): # Examine if the given synset is present in the ShapeNetCore dataset @@ -95,12 +119,15 @@ class R2N2(ShapeNetBase): ) % (shapenet_dir, ", ".join(synset_not_present)) warnings.warn(msg) - def __getitem__(self, idx: int) -> Dict: + def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict: """ Read a model by the given index. Args: - idx: The idx of the model to be retrieved in the dataset. + model_idx: The idx of the model to be retrieved in the dataset. + view_idx: List of indices of the view to be returned. Each index needs to be + between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be + ignored and views will be sampled according to self.return_all_views. Returns: dictionary with following keys: @@ -109,12 +136,51 @@ class R2N2(ShapeNetBase): - synset_id (str): synset id. - model_id (str): model id. - label (str): synset label. + - images: FloatTensor of shape (V, H, W, C), where V is number of views + returned. Returns a batch of the renderings of the models from the R2N2 dataset. """ - model = self._get_item_ids(idx) + if type(model_idx) is tuple: + model_idx, view_idxs = model_idx + model = self._get_item_ids(model_idx) model_path = path.join( self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj" ) model["verts"], faces, _ = load_obj(model_path) model["faces"] = faces.verts_idx model["label"] = self.synset_dict[model["synset_id"]] + + model["images"] = None + # Retrieve R2N2's renderings if required. + if self.return_images: + ranges = ( + range(24) if self.return_all_views else torch.randint(24, (1,)).tolist() + ) + if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs): + msg = ( + "One of the indicies in view_idxs is out of range. " + "Index needs to be between 0 and 23, inclusive. " + "Now sampling according to self.return_all_views." + ) + warnings.warn(msg) + elif view_idxs is not None: + ranges = view_idxs + + rendering_path = path.join( + self.r2n2_dir, + "ShapeNetRendering", + model["synset_id"], + model["model_id"], + "rendering", + ) + + images = [] + for i in ranges: + # Read image. + image_path = path.join(rendering_path, "%02d.png" % i) + raw_img = Image.open(image_path) + image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3] + images.append(image.to(dtype=torch.float32)) + + model["images"] = torch.stack(images) + return model diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py index 81616b51..922ff780 100644 --- a/pytorch3d/datasets/utils.py +++ b/pytorch3d/datasets/utils.py @@ -2,6 +2,7 @@ from typing import Dict, List +import torch from pytorch3d.structures import Meshes @@ -32,4 +33,10 @@ def collate_batched_meshes(batch: List[Dict]): verts=collated_dict["verts"], faces=collated_dict["faces"] ) + # If collate_batched_meshes receives R2N2 items, stack the batches of + # views of each model into a new batch of shape (N, V, H, W, 3) where + # V is the number of views. + if "images" in collated_dict: + collated_dict["images"] = torch.stack(collated_dict["images"]) + return collated_dict diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py index ba5fdf69..1de4222d 100644 --- a/tests/test_r2n2.py +++ b/tests/test_r2n2.py @@ -56,7 +56,8 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): def test_load_R2N2(self): """ - Test the loaded train split of R2N2 return items of the correct shapes and types. + Test the loaded train split of R2N2 return items of the correct shapes and types. Also + check the first image returned is correct. """ # Load dataset in the train split. r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) @@ -68,8 +69,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(len(r2n2_dataset), sum(model_nums)) # Randomly retrieve an object from the dataset. - rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))] - # Check that data type and shape of the item returned by __getitem__ are correct. + rand_idx = torch.randint(len(r2n2_dataset), (1,)) + rand_obj = r2n2_dataset[rand_idx] + # Check that verts and faces returned by __getitem__ have the correct shapes and types. verts, faces = rand_obj["verts"], rand_obj["faces"] self.assertTrue(verts.dtype == torch.float32) self.assertTrue(faces.dtype == torch.int64) @@ -78,6 +80,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(faces.ndim, 2) self.assertEqual(faces.shape[-1], 3) + # Check that image batch returned by __getitem__ has the correct shape. + self.assertEqual(rand_obj["images"].shape[0], 24) + self.assertEqual(rand_obj["images"].shape[1], 137) + self.assertEqual(rand_obj["images"].shape[2], 137) + self.assertEqual(rand_obj["images"].shape[-1], 3) + self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1) + def test_collate_models(self): """ Test collate_batched_meshes returns items of the correct shapes and types. @@ -118,6 +127,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(len(object_batch["label"]), batch_size) self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size) self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size) + self.assertEqual(object_batch["images"].shape[0], batch_size) def test_catch_render_arg_errors(self): """