From 358e211cde4412c24675af3d048f2d6d4391df59 Mon Sep 17 00:00:00 2001 From: Luya Gao Date: Tue, 14 Jul 2020 14:52:21 -0700 Subject: [PATCH] Adding renderer for ShapeNetBase Summary: Adding a renderer to ShapeNetCore (Note that the lights are currently turned off for the test; will investigate why lighting causes instability in rendering) Reviewed By: nikhilaravi Differential Revision: D22102673 fbshipit-source-id: a704756a1e93b61d5a879f0e5ee14ebcb0df49d7 --- pytorch3d/datasets/shapenet/__init__.py | 1 + pytorch3d/datasets/shapenet/shapenet_core.py | 82 +++++++------- pytorch3d/datasets/shapenet_base.py | 107 ++++++++++++++++++ .../data/test_shapenet_core_render_piano.png | Bin 0 -> 3268 bytes tests/test_shapenet_core.py | 47 +++++++- 5 files changed, 194 insertions(+), 43 deletions(-) create mode 100644 pytorch3d/datasets/shapenet_base.py create mode 100644 tests/data/test_shapenet_core_render_piano.png diff --git a/pytorch3d/datasets/shapenet/__init__.py b/pytorch3d/datasets/shapenet/__init__.py index dd0bc863..44469dab 100644 --- a/pytorch3d/datasets/shapenet/__init__.py +++ b/pytorch3d/datasets/shapenet/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + from .shapenet_core import ShapeNetCore diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py index 25b6bca0..e28ae797 100644 --- a/pytorch3d/datasets/shapenet/shapenet_core.py +++ b/pytorch3d/datasets/shapenet/shapenet_core.py @@ -5,15 +5,16 @@ import os import warnings from os import path from pathlib import Path +from typing import Dict -import torch +from pytorch3d.datasets.shapenet_base import ShapeNetBase from pytorch3d.io import load_obj SYNSET_DICT_DIR = Path(__file__).resolve().parent -class ShapeNetCore(torch.utils.data.Dataset): +class ShapeNetCore(ShapeNetBase): """ This class loads ShapeNetCore from a given directory into a Dataset object. ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from @@ -23,6 +24,7 @@ class ShapeNetCore(torch.utils.data.Dataset): def __init__(self, data_dir, synsets=None, version: int = 1): """ Store each object's synset id and models id from data_dir. + Args: data_dir: Path to ShapeNetCore data. synsets: List of synset categories to load from ShapeNetCore in the form of @@ -38,6 +40,7 @@ class ShapeNetCore(torch.utils.data.Dataset): version 1. """ + super().__init__() self.data_dir = data_dir if version not in [1, 2]: raise ValueError("Version number must be either 1 or 2.") @@ -48,7 +51,7 @@ class ShapeNetCore(torch.utils.data.Dataset): with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict: self.synset_dict = json.load(read_dict) # Inverse dicitonary mapping synset labels to corresponding offsets. - synset_inv = {label: offset for offset, label in self.synset_dict.items()} + self.synset_inv = {label: offset for offset, label in self.synset_dict.items()} # If categories are specified, check if each category is in the form of either # synset offset or synset label, and if the category exists in the given directory. @@ -60,62 +63,61 @@ class ShapeNetCore(torch.utils.data.Dataset): path.isdir(path.join(data_dir, synset)) ): synset_set.add(synset) - elif (synset in synset_inv.keys()) and ( - (path.isdir(path.join(data_dir, synset_inv[synset]))) + elif (synset in self.synset_inv.keys()) and ( + (path.isdir(path.join(data_dir, self.synset_inv[synset]))) ): - synset_set.add(synset_inv[synset]) + synset_set.add(self.synset_inv[synset]) else: - msg = """Synset category %s either not part of ShapeNetCore dataset - or cannot be found in %s.""" % ( - synset, - data_dir, - ) + msg = ( + "Synset category %s either not part of ShapeNetCore dataset " + "or cannot be found in %s." + ) % (synset, data_dir) warnings.warn(msg) # If no category is given, load every category in the given directory. + # Ignore synset folders not included in the official mapping. else: synset_set = { synset for synset in os.listdir(data_dir) if path.isdir(path.join(data_dir, synset)) + and synset in self.synset_dict } - for synset in synset_set: - if synset not in self.synset_dict.keys(): - msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s - but not found in %s.""" % ( - synset, - self.synset_dict[synset], - version, - data_dir, - ) - warnings.warn(msg) + + # Check if there are any categories in the official mapping that are not loaded. + # Update self.synset_inv so that it only includes the loaded categories. + synset_not_present = set(self.synset_dict.keys()).difference(synset_set) + [self.synset_inv.pop(self.synset_dict[synset]) for synset in synset_not_present] + + if len(synset_not_present) > 0: + msg = ( + "The following categories are included in ShapeNetCore ver.%d's " + "official mapping but not found in the dataset location %s: %s" + "" + ) % (version, data_dir, ", ".join(synset_not_present)) + warnings.warn(msg) # Extract model_id of each object from directory names. # Each grandchildren directory of data_dir contains an object, and the name # of the directory is the object's model_id. - self.synset_ids = [] - self.model_ids = [] for synset in synset_set: for model in os.listdir(path.join(data_dir, synset)): if not path.exists(path.join(data_dir, synset, model, self.model_dir)): - msg = """ Object file not found in the model directory %s - under synset directory %s.""" % ( - model, - synset, - ) + msg = ( + "Object file not found in the model directory %s " + "under synset directory %s." + ) % (model, synset) warnings.warn(msg) - else: - self.synset_ids.append(synset) - self.model_ids.append(model) + continue + self.synset_ids.append(synset) + self.model_ids.append(model) - def __len__(self): - """ - Return number of total models in shapenet core. - """ - return len(self.model_ids) - - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Dict: """ Read a model by the given index. + + Args: + idx: The idx of the model to be retrieved in the dataset. + Returns: dictionary with following keys: - verts: FloatTensor of shape (V, 3). @@ -124,9 +126,7 @@ class ShapeNetCore(torch.utils.data.Dataset): - model_id (str): model id - label (str): synset label. """ - model = {} - model["synset_id"] = self.synset_ids[idx] - model["model_id"] = self.model_ids[idx] + model = self._get_item_ids(idx) model_path = path.join( self.data_dir, model["synset_id"], model["model_id"], self.model_dir ) diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py new file mode 100644 index 00000000..f76546ce --- /dev/null +++ b/pytorch3d/datasets/shapenet_base.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Dict + +import torch +from pytorch3d.renderer import ( + HardPhongShader, + MeshRasterizer, + MeshRenderer, + OpenGLPerspectiveCameras, + PointLights, + RasterizationSettings, +) +from pytorch3d.structures import Meshes, Textures + + +class ShapeNetBase(torch.utils.data.Dataset): + """ + 'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods. + It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__ + and __getitem__ need to be implemented. + """ + + def __init__(self): + """ + Set up lists of synset_ids and model_ids. + """ + self.synset_ids = [] + self.model_ids = [] + + def __len__(self): + """ + Return number of total models in the loaded dataset. + """ + return len(self.model_ids) + + def __getitem__(self, idx) -> Dict: + """ + Read a model by the given index. Need to be implemented for every child class + of ShapeNetBase. + + Args: + idx: The idx of the model to be retrieved in the dataset. + + Returns: + dictionary containing information about the model. + """ + raise NotImplementedError( + "__getitem__ should be implemented in the child class of ShapeNetBase" + ) + + def _get_item_ids(self, idx) -> Dict: + """ + Read a model by the given index. + + Args: + idx: The idx of the model to be retrieved in the dataset. + + Returns: + dictionary with following keys: + - synset_id (str): synset id + - model_id (str): model id + """ + model = {} + model["synset_id"] = self.synset_ids[idx] + model["model_id"] = self.model_ids[idx] + return model + + def render( + self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs + ) -> torch.Tensor: + """ + Renders a model by the given index. + + Args: + idx: The index of model to be rendered in the dataset. + shader_type: select shading. Valid options include HardPhongShader (default), + SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader, + SoftSilhouetteShader. + device: torch.device on which the tensors should be located. + **kwargs: Accepts any of the kwargs that the renderer supports. + + Returns: + Rendered image of shape (1, H, W, 3). + """ + + model = self.__getitem__(idx) + verts, faces = model["verts"], model["faces"] + verts_rgb = torch.ones_like(verts, device=device)[None] + mesh = Meshes( + verts=[verts.to(device)], + faces=[faces.to(device)], + textures=Textures(verts_rgb=verts_rgb.to(device)), + ) + cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=kwargs.get("raster_settings", RasterizationSettings()), + ), + shader=shader_type( + device=device, + cameras=cameras, + lights=kwargs.get("lights", PointLights()).to(device), + ), + ) + return renderer(mesh) diff --git a/tests/data/test_shapenet_core_render_piano.png b/tests/data/test_shapenet_core_render_piano.png new file mode 100644 index 0000000000000000000000000000000000000000..fc7524c89e5ea51bd36560a19efb938f75681c80 GIT binary patch literal 3268 zcmeHKX;f257QR`4Ad3+M0nw&@ifXJ64@Feh!H_-6%|bY zx4}cB5@9f*o1;-Sg@!<*LSzez1O!YFRKgM_(9Rrj{L%Ac{`LEF>(>3MUe)*AdUapw z7H?0r1$qkrK+S7|+cp5Cyom(NtclZey$7J0=H<42do*jfz4d6@vj?Qc^W{eM#5+cP z=yM)=wv$Mkl_7;6df9}c4am~1q+)5$=;){g#L6N+Ac1=dz)cxQnSY!A0|gR`b#u>s zvVuqhgu4Q61)x5D>@3g$^NI=|>4;Ps^Ffb3*p8e2g}sVHhbCeT0l&ezvlb0E-|0@j zXWJv_Ea&Q92Oh8G!}Ur(QJMj4cd?m>x=8{`;yLfgCgS%%q57mBZF*d84K0@k_(y2K zz9H)44g?qj`_df-cfhDu8#`1!Ci+euN#1439mq9-HLYnqtz#CNkoES1;K)REj3O9c zm6g@2!zsrD;Y=9@HXnWQ!fsm11Ri)0sS;J#Nu7w)a0b8T6UMD$qT@(dWi=6l;?OCM zc|>ub5~w)R$V^$ADX>qLGfD>d95k#B9*flCgK%RWQKtw+4`Es6|Bz_76eOV@#I$e~ zOh?Ta)qn;PQx9UnB5-K;3VJ;zaz}!h)x_?rR0!8@%B>V2cq8@9#yp~u1~^=3AZUCW)+-w4&mj|&aMt8YO`4|-?pB2}1^#}&q28SAFK2_`BI;sfx z@s8^1>bSfa@r-W0F-XqZrzbS!ZoNLNdx+w087U6?MdgZqG*1 z6Qj#_%ru+DP9&zd%E6C0XJyXB^Qz%1Idoy+@r>7ub?29?XVr)iTRzo~{Fqt1*ctwK zRVJC4Hd?7f#llU+ed8fu7NK&QW$&{snR!rvz*-$2Yek1ETA=>1|J|7VgkvhOA9JQd zoB6pB4Gfc70-i>DUg0t+M%9An2sUf&K3kI?sEYC_)P|j=hr!{pKk4wBr(NTHohvRP z$Rcelb9s8gZ0``>b~b)SF7aeQx3re(^Ap+57&b=@_Xh+n2zUf}qu8vNow)eHyMZZ} z)?^w$`Q0_*oqnh?Z@~4uk6Typ03~z}MOgB-S%>CVK7&F@9GD*|DSh}lR;U$n(%zIzubKpTD(np&aV!~sFLOh2bE&X5Xhd1=39&RNYn>>Ln%3$F z*5qZop!2aoUjGZoYZ^%&4ai zhQOin`gx~A`R9t1g?4;!=(wJ0eJI~&Rs30(Y63pP=FltEJ#U+))NDf|RZa)}u4hxc zb8@MD(zvEGkZnVXe&0i2)mQ`BDkP_V&+yhE7}w;`qyONZunK)reF2dmCw$=&d$%eA zuA`iQH?$7T%J7$X$zj~J@9(~5(X}Ln3RAzOHU-KZ%aK!tW^xUTPI)s@Lmsih^k^`Zd!vWaUsPyo zGpb+jN|EU~d6?agS&^gIqrIiQA?-GJ;z*T?JK3D$KjaqU;!^cQGhknyEN2=Le zw$47OrVUx68{2LlQI|aa(XLcsvDMhJnS+ACM6JO>3DwmC<8ab|^tA0_e1DPI`+Mpd z^l7!V4@4dz`kf+kM2nz7*XfwxR?OnI-W{a4qo(4wWw}Gmo;1m^7Q|518*6?<3tE}C zlT}{!^z3y#M9WE)MDg>y3eY(u>%jjr4rGyi~@B;R&~GVaE_uU zhPTXv7}giv%7pZ;?Y-IhVG77Kbd|)_y)CU-^2F=zH{wua6-?K|{7#hLiwfV+?rWt- zx=k>jicx-zT;KYvmYiJiWi*A*V>-Z9@cASYKg<=qe0n3bvcEF&Qn`gr&iEnqb{%C> zrZwZvPiJM5ON7z>7vY}{3rnYpvZFAhq~c>{3D@-9K_kadH(2fPX5hL5C5WsTtD3BV z^8~nw2K>f3gL|ILY5vyTY#gK2F!&$XXY`*&h$Ej@Og9GDiUh35muB;um&YwMfW92z zK(ED8mLlM^22j>aig#gn1^S7+Epr2ShiP3tnzN|r)J+R1s}WEYx#FY#ESN4T?FWy! z7|UXVVS?O--AjTR!*&2dliW6coU zE3OL^Bc67+!n8c8!_)j3 zysB?+`QLQSvQj(aVtj*ok|lwAgcJ;vQYz^{692Hk@tr zRy5P<-8o$Bc$@kF#l4Vh$o&~WC(`7Y7J{4QYQ#MQz=-Kis45?Db29zI=i7cehq_^cy@nXqA~-U-=xyI*9nO|55qcJr8g4Q zZwU~y#VBEQS63SFP-Ds2bF1SKxm9;up+%r+(KE6V?^LFmS=VkTSoi_#f|S^3U&UzR zTr$g(b~dNy8{vk8{DW8ou1cUFKXj&+)4ESiIHkzHHMjR;8H1S5_dV#aQU5)5{!D*+ gUs3xFLMDuq5AS{!vlaWR{O<(3+`ZlIe@{vJ7cE<%7XSbN literal 0 HcmV?d00001 diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py index ff623f78..db92c83a 100644 --- a/tests/test_shapenet_core.py +++ b/tests/test_shapenet_core.py @@ -6,17 +6,32 @@ import os import random import unittest import warnings +from pathlib import Path +import numpy as np import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, load_rgb_image +from PIL import Image from pytorch3d.datasets import ShapeNetCore +from pytorch3d.renderer import ( + OpenGLPerspectiveCameras, + PointLights, + RasterizationSettings, + look_at_view_transform, +) SHAPENET_PATH = None +# If DEBUG=True, save out images generated in the tests for debugging. +# All saved images have prefix DEBUG_ +DEBUG = False +DATA_DIR = Path(__file__).resolve().parent / "data" class TestShapenetCore(TestCaseMixin, unittest.TestCase): def test_load_shapenet_core(self): + # Setup + device = torch.device("cuda:0") # The ShapeNet dataset is not provided in the repo. # Download this separately and update the `shapenet_path` @@ -31,7 +46,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): warnings.warn(msg) return True - # Try load ShapeNetCore with an invalid version number and catch error. + # Try loading ShapeNetCore with an invalid version number and catch error. with self.assertRaises(ValueError) as err: ShapeNetCore(SHAPENET_PATH, version=3) self.assertTrue("Version number must be either 1 or 2." in str(err.exception)) @@ -93,3 +108,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): for offset in subset_offsets ] self.assertEqual(len(shapenet_subset), sum(subset_model_nums)) + + # Render the first image in the piano category. + R, T = look_at_view_transform(1.0, 1.0, 90) + piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"]) + + cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device) + raster_settings = RasterizationSettings(image_size=512) + lights = PointLights( + location=torch.tensor([0.0, 1.0, -2.0], device=device)[None], + # TODO: debug the source of the discrepancy in two images when rendering on GPU. + diffuse_color=((0, 0, 0),), + specular_color=((0, 0, 0),), + device=device, + ) + images = piano_dataset.render( + 0, + device=device, + cameras=cameras, + raster_settings=raster_settings, + lights=lights, + ) + rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_shapenet_core_render_piano.png" + ) + image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR) + self.assertClose(rgb, image_ref, atol=0.05)