From 22f2963cf11d11c50fc75e49199b14e091d922ee Mon Sep 17 00:00:00 2001
From: Luya Gao
Date: Tue, 14 Jul 2020 14:52:21 -0700
Subject: [PATCH] Render objects in a batch by the specified model_ids, categories or idxs for ShapeNetBase

Summary: Additional functionality for the renderer in ShapeNetCore: users can select which objects to render by specifying their model_ids, by sampling a given number of random objects from one or more categories, or by specifying the indices of the objects in the loaded dataset. (Changing the lighting is not yet supported; we are still investigating why lighting causes instability in the renderings.) A usage sketch of the new render API follows the patch.

Reviewed By: nikhilaravi

Differential Revision: D22179594

fbshipit-source-id: 74c49094ffa3ea2eb71de9451f9e5da5053d356d
---
 pytorch3d/datasets/__init__.py | 1 +
 pytorch3d/datasets/shapenet/shapenet_core.py | 6 +-
 pytorch3d/datasets/shapenet_base.py | 177 ++++++++++++++++--
 ...enet_core_render_mixed_by_categories_0.png | Bin 0 -> 2517 bytes
 ...enet_core_render_mixed_by_categories_1.png | Bin 0 -> 2275 bytes
 ...enet_core_render_mixed_by_categories_2.png | Bin 0 -> 2633 bytes
 .../test_shapenet_core_render_piano_0.png | Bin 0 -> 3268 bytes
 .../test_shapenet_core_render_piano_1.png | Bin 0 -> 2527 bytes
 .../test_shapenet_core_render_piano_2.png | Bin 0 -> 2259 bytes
 ...enet_core_render_without_sample_nums_0.png | Bin 0 -> 3201 bytes
 ...enet_core_render_without_sample_nums_1.png | Bin 0 -> 2767 bytes
 tests/test_shapenet_core.py | 171 ++++++++++++++---
 12 files changed, 308 insertions(+), 47 deletions(-)
 create mode 100644 tests/data/test_shapenet_core_render_mixed_by_categories_0.png
 create mode 100644 tests/data/test_shapenet_core_render_mixed_by_categories_1.png
 create mode 100644 tests/data/test_shapenet_core_render_mixed_by_categories_2.png
 create mode 100644 tests/data/test_shapenet_core_render_piano_0.png
 create mode 100644 tests/data/test_shapenet_core_render_piano_1.png
 create mode 100644 tests/data/test_shapenet_core_render_piano_2.png
 create mode 100644 tests/data/test_shapenet_core_render_without_sample_nums_0.png
 create mode 100644 tests/data/test_shapenet_core_render_without_sample_nums_1.png

diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py
index 3cf0f3f3..243247e5 100644
--- a/pytorch3d/datasets/__init__.py
+++ b/pytorch3d/datasets/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
 from .shapenet import ShapeNetCore

diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py
index e28ae797..79852975 100644
--- a/pytorch3d/datasets/shapenet/shapenet_core.py
+++ b/pytorch3d/datasets/shapenet/shapenet_core.py
@@ -41,7 +41,7 @@ class ShapeNetCore(ShapeNetBase):
         """
         super().__init__()
-        self.data_dir = data_dir
+        self.shapenet_dir = data_dir
         if version not in [1, 2]:
             raise ValueError("Version number must be either 1 or 2.")
         self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
@@ -100,6 +100,7 @@ class ShapeNetCore(ShapeNetBase):
         # Each grandchildren directory of data_dir contains an object, and the name
         # of the directory is the object's model_id.
         for synset in synset_set:
+            self.synset_starts[synset] = len(self.synset_ids)
             for model in os.listdir(path.join(data_dir, synset)):
                 if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
                     msg = (
@@ -110,6 +111,7 @@ class ShapeNetCore(ShapeNetBase):
                     continue
                 self.synset_ids.append(synset)
                 self.model_ids.append(model)
+            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]

     def __getitem__(self, idx: int) -> Dict:
         """
@@ -128,7 +130,7 @@ class ShapeNetCore(ShapeNetBase):
         """
         model = self._get_item_ids(idx)
         model_path = path.join(
-            self.data_dir, model["synset_id"], model["model_id"], self.model_dir
+            self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir
         )
         model["verts"], faces, _ = load_obj(model_path)
         model["faces"] = faces.verts_idx
diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py
index f76546ce..daf156be 100644
--- a/pytorch3d/datasets/shapenet_base.py
+++ b/pytorch3d/datasets/shapenet_base.py
@@ -1,8 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from typing import Dict
+import warnings
+from os import path
+from typing import Dict, List, Optional

 import torch
+from pytorch3d.io import load_objs_as_meshes
 from pytorch3d.renderer import (
     HardPhongShader,
     MeshRasterizer,
@@ -11,7 +14,7 @@ from pytorch3d.renderer import (
     PointLights,
     RasterizationSettings,
 )
-from pytorch3d.structures import Meshes, Textures
+from pytorch3d.structures import Textures


 class ShapeNetBase(torch.utils.data.Dataset):
@@ -27,6 +30,11 @@ class ShapeNetBase(torch.utils.data.Dataset):
         """
         self.synset_ids = []
         self.model_ids = []
+        self.synset_inv = {}
+        self.synset_starts = {}
+        self.synset_lens = {}
+        self.shapenet_dir = ""
+        self.model_dir = ""

     def __len__(self):
         """
@@ -67,30 +75,46 @@ class ShapeNetBase(torch.utils.data.Dataset):
         return model

     def render(
-        self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
+        self,
+        model_ids: Optional[List[str]] = None,
+        categories: Optional[List[str]] = None,
+        sample_nums: Optional[List[int]] = None,
+        idxs: Optional[List[int]] = None,
+        shader_type=HardPhongShader,
+        device="cpu",
+        **kwargs
     ) -> torch.Tensor:
         """
-        Renders a model by the given index.
+        If a list of model_ids is supplied, render the objects with those model_ids.
+        If no model_ids are supplied but categories are specified, randomly select
+        a number of objects (given by sample_nums, default 1 per category) from the
+        given categories and render them. If instead a list of idxs is specified,
+        check that the idxs are all valid and render the models at those indices.
+        Otherwise, randomly select a number of models (the first entry of sample_nums,
+        default 1) from the whole loaded dataset and render them.

         Args:
-            idx: The index of model to be rendered in the dataset.
-            shader_type: select shading. Valid options include HardPhongShader (default),
+            model_ids: List[str] of model_ids of the models to be rendered.
+            categories: List[str] of categories to be rendered. categories can be
+                given as synset offsets or labels, or a combination of both. If
+                sample_nums is not given, one model is sampled from each category.
+            sample_nums: List[int] of the number of models to be randomly sampled
+                from each category. It may also contain a single integer, in which
+                case that number is broadcast to every category.
+            idxs: List[int] of indices of the models to be rendered in the dataset.
+            shader_type: Select shading.
+                Valid options include HardPhongShader (default),
                SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
                SoftSilhouetteShader.
            device: torch.device on which the tensors should be located.
            **kwargs: Accepts any of the kwargs that the renderer supports.

        Returns:
-            Rendered image of shape (1, H, W, 3).
+            Batch of rendered images of shape (N, H, W, 3).
         """
-
-        model = self.__getitem__(idx)
-        verts, faces = model["verts"], model["faces"]
-        verts_rgb = torch.ones_like(verts, device=device)[None]
-        mesh = Meshes(
-            verts=[verts.to(device)],
-            faces=[faces.to(device)],
-            textures=Textures(verts_rgb=verts_rgb.to(device)),
+        paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
+        meshes.textures = Textures(
+            verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
         )
         cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
         renderer = MeshRenderer(
@@ -104,4 +128,125 @@ class ShapeNetBase(torch.utils.data.Dataset):
                 lights=kwargs.get("lights", PointLights()).to(device),
             ),
         )
-        return renderer(mesh)
+        return renderer(meshes)
+
+    def _handle_render_inputs(
+        self,
+        model_ids: Optional[List[str]] = None,
+        categories: Optional[List[str]] = None,
+        sample_nums: Optional[List[int]] = None,
+        idxs: Optional[List[int]] = None,
+    ) -> List[str]:
+        """
+        Helper function for converting user-provided model_ids, categories and
+        sample_nums to indices of models in the loaded dataset. If idxs are provided,
+        they are checked for validity. If no models are specified at all, a number of
+        models (the first entry of sample_nums, default 1) is randomly sampled from
+        the loaded dataset. The function returns the file paths to the selected models.
+
+        Args:
+            model_ids: List[str] of model_ids of the models to be rendered.
+            categories: List[str] of categories to be rendered.
+            sample_nums: List[int] of the number of models to be randomly sampled from
+                each category.
+            idxs: List[int] of indices of the models to be rendered in the dataset.
+
+        Returns:
+            List of paths of the models to be rendered.
+        """
+        # Get the corresponding indices if model_ids are supplied.
+        if model_ids is not None and len(model_ids) > 0:
+            idxs = []
+            for model_id in model_ids:
+                if model_id not in self.model_ids:
+                    raise ValueError(
+                        "model_id %s not found in the loaded dataset." % model_id
+                    )
+                idxs.append(self.model_ids.index(model_id))
+
+        # Sample random models if categories are supplied and get the
+        # corresponding indices.
+        elif categories is not None and len(categories) > 0:
+            sample_nums = [1] if sample_nums is None else sample_nums
+            if len(categories) != len(sample_nums) and len(sample_nums) != 1:
+                raise ValueError(
+                    "categories and sample_nums need to be of the same length or "
+                    "sample_nums needs to be of length 1."
+                )
+
+            idxs_tensor = torch.empty(0, dtype=torch.int32)
+            for i in range(len(categories)):
+                category = self.synset_inv.get(categories[i], categories[i])
+                if category not in self.synset_inv.values():
+                    raise ValueError(
+                        "Category %s is not in the loaded dataset." % category
+                    )
+                # Broadcast if sample_nums has a length of 1.
+                sample_num = sample_nums[i] if len(sample_nums) > 1 else sample_nums[0]
+                sampled_idxs = self._sample_idxs_from_category(
+                    sample_num=sample_num, category=category
+                )
+                idxs_tensor = torch.cat((idxs_tensor, sampled_idxs))
+            idxs = idxs_tensor.tolist()
+        # Check that the indices are valid if idxs are supplied.
+        elif idxs is not None and len(idxs) > 0:
+            if any(idx < 0 or idx >= len(self.model_ids) for idx in idxs):
+                raise IndexError(
+                    "One or more idx values are out of bounds. Indices need to be "
+                    "between 0 and %s." % (len(self.model_ids) - 1)
+                )
+        # If sample_nums is specified, sample sample_nums[0] indices from the entire
+        # loaded dataset; otherwise randomly select one index from the dataset.
+        else:
+            sample_nums = [1] if sample_nums is None else sample_nums
+            if len(sample_nums) > 1:
+                msg = (
+                    "More than one sample size specified, now sampling "
+                    "%d models from the dataset." % sample_nums[0]
+                )
+                warnings.warn(msg)
+            idxs = self._sample_idxs_from_category(sample_nums[0])
+        return [
+            path.join(
+                self.shapenet_dir,
+                self.synset_ids[idx],
+                self.model_ids[idx],
+                self.model_dir,
+            )
+            for idx in idxs
+        ]
+
+    def _sample_idxs_from_category(
+        self, sample_num: int = 1, category: Optional[str] = None
+    ) -> torch.Tensor:
+        """
+        Helper function for sampling a number of indices from the given category.
+
+        Args:
+            sample_num: number of indices to be sampled from the given category.
+            category: synset of the category to be sampled from. If not specified,
+                sample from all models in the loaded dataset.
+        """
+        start = self.synset_starts[category] if category is not None else 0
+        range_len = (
+            self.synset_lens[category] if category is not None else self.__len__()
+        )
+        replacement = sample_num > range_len
+        sampled_idxs = (
+            torch.multinomial(
+                torch.ones((range_len), dtype=torch.float32),
+                sample_num,
+                replacement=replacement,
+            )
+            + start
+        )
+        if replacement:
+            msg = (
+                "Sample size %d is larger than the number of objects in %s, "
+                "values sampled with replacement."
+            ) % (
+                sample_num,
+                "category " + category if category is not None else "all categories",
+            )
+            warnings.warn(msg)
+        return sampled_idxs
[GIT binary patch literals omitted for the eight new reference images added under tests/data/: test_shapenet_core_render_mixed_by_categories_{0,1,2}.png, test_shapenet_core_render_piano_{0,1,2}.png and test_shapenet_core_render_without_sample_nums_{0,1}.png; see the file list at the top of the patch.]
diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py
index db92c83a..a1fc664c 100644
--- a/tests/test_shapenet_core.py
+++ b/tests/test_shapenet_core.py
@@ -1,11 +1,9 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 """
-Sanity checks for loading ShapeNet Core v1.
+Sanity checks for loading ShapeNetCore.
 """
 import os
-import random
 import unittest
-import warnings
 from pathlib import Path

 import numpy as np
@@ -21,6 +19,7 @@ from pytorch3d.renderer import (
 )


+# Set SHAPENET_PATH to the local path to the dataset.
 SHAPENET_PATH = None
 # If DEBUG=True, save out images generated in the tests for debugging.
 # All saved images have prefix DEBUG_
@@ -29,23 +28,26 @@ DATA_DIR = Path(__file__).resolve().parent / "data"


 class TestShapenetCore(TestCaseMixin, unittest.TestCase):
-    def test_load_shapenet_core(self):
-        # Setup
-        device = torch.device("cuda:0")
-
-        # The ShapeNet dataset is not provided in the repo.
-        # Download this separately and update the `shapenet_path`
-        # with the location of the dataset in order to run this test.
+    def setUp(self):
+        """
+        Check that SHAPENET_PATH points to a local copy of the ShapeNet dataset.
+        If not, download the dataset separately and update SHAPENET_PATH at the
+        top of the file in order to run the tests.
+        """
         if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
             url = "https://www.shapenet.org/"
-            msg = """ShapeNet data not found, download from %s, save it at the path %s,
-                update SHAPENET_PATH at the top of the file, and rerun""" % (
-                url,
-                SHAPENET_PATH,
+            msg = (
+                "ShapeNet data not found, download from %s, update "
+                "SHAPENET_PATH at the top of the file, and rerun."
             )
-            warnings.warn(msg)
-            return True
+            self.skipTest(msg % url)
+
+    def test_load_shapenet_core(self):
+        """
+        Test loading both the entire ShapeNetCore dataset and a subset of the
+        ShapeNetCore dataset. Check that the loaded datasets return items of the
+        correct shapes and types.
+        """
         # Try loading ShapeNetCore with an invalid version number and catch error.
         with self.assertRaises(ValueError) as err:
             ShapeNetCore(SHAPENET_PATH, version=3)
@@ -70,8 +72,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
         self.assertEqual(len(shapenet_dataset), sum(model_num_list))

         # Randomly retrieve an object from the dataset.
- rand_obj = random.choice(shapenet_dataset) - self.assertEqual(len(rand_obj), 5) + rand_obj = shapenet_dataset[torch.randint(len(shapenet_dataset), (1,))] # Check that data types and shapes of items returned by __getitem__ are correct. verts, faces = rand_obj["verts"], rand_obj["faces"] self.assertTrue(verts.dtype == torch.float32) @@ -82,7 +83,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): self.assertEqual(faces.shape[-1], 3) # Load six categories from ShapeNetCore. - # Specify categories in the form of a combination of offsets and labels. + # Specify categories with a combination of offsets and labels. shapenet_subset = ShapeNetCore( SHAPENET_PATH, synsets=[ @@ -109,10 +110,37 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): ] self.assertEqual(len(shapenet_subset), sum(subset_model_nums)) - # Render the first image in the piano category. - R, T = look_at_view_transform(1.0, 1.0, 90) + def test_catch_render_arg_errors(self): + """ + Test rendering ShapeNetCore with invalid model_ids, categories or indices, + and catch corresponding errors. + """ + # Load ShapeNetCore. + shapenet_dataset = ShapeNetCore(SHAPENET_PATH) + + # Try loading with an invalid model_id and catch error. + with self.assertRaises(ValueError) as err: + shapenet_dataset.render(model_ids=["piano0"]) + self.assertTrue("not found in the loaded dataset" in str(err.exception)) + + # Try loading with an index out of bounds and catch error. + with self.assertRaises(IndexError) as err: + shapenet_dataset.render(idxs=[100000]) + self.assertTrue("are out of bounds" in str(err.exception)) + + def test_render_shapenet_core(self): + """ + Test rendering objects from ShapeNetCore. + """ + # Setup device and seed for random selections. + device = torch.device("cuda:0") + torch.manual_seed(39) + + # Load category piano from ShapeNetCore. piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"]) + # Rendering settings. + R, T = look_at_view_transform(1.0, 1.0, 90) cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device) raster_settings = RasterizationSettings(image_size=512) lights = PointLights( @@ -122,17 +150,102 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): specular_color=((0, 0, 0),), device=device, ) - images = piano_dataset.render( - 0, + + # Render first three models in the piano category. + pianos = piano_dataset.render( + idxs=list(range(3)), device=device, cameras=cameras, raster_settings=raster_settings, lights=lights, ) - rgb = images[0, ..., :3].squeeze().cpu() - if DEBUG: - Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / "DEBUG_shapenet_core_render_piano.png" + # Check that there are three images in the batch. + self.assertEqual(pianos.shape[0], 3) + + # Compare the rendered models to the reference images. + for idx in range(3): + piano_rgb = pianos[idx, ..., :3].squeeze().cpu() + if DEBUG: + Image.fromarray((piano_rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / ("DEBUG_shapenet_core_render_piano_by_idxs_%s.png" % idx) + ) + image_ref = load_rgb_image( + "test_shapenet_core_render_piano_%s.png" % idx, DATA_DIR ) - image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR) - self.assertClose(rgb, image_ref, atol=0.05) + self.assertClose(piano_rgb, image_ref, atol=0.05) + + # Render the same piano models but by model_ids this time. 
+ pianos_2 = piano_dataset.render( + model_ids=[ + "13394ca47c89f91525a3aaf903a41c90", + "14755c2ee8e693aba508f621166382b0", + "156c4207af6d2c8f1fdc97905708b8ea", + ], + device=device, + cameras=cameras, + raster_settings=raster_settings, + lights=lights, + ) + + # Compare the rendered models to the reference images. + for idx in range(3): + piano_rgb_2 = pianos_2[idx, ..., :3].squeeze().cpu() + if DEBUG: + Image.fromarray((piano_rgb_2.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / ("DEBUG_shapenet_core_render_piano_by_ids_%s.png" % idx) + ) + image_ref = load_rgb_image( + "test_shapenet_core_render_piano_%s.png" % idx, DATA_DIR + ) + self.assertClose(piano_rgb_2, image_ref, atol=0.05) + + ####################### + # Render by categories + ####################### + + # Load ShapeNetCore. + shapenet_dataset = ShapeNetCore(SHAPENET_PATH) + + # Render a mixture of categories and specify the number of models to be + # randomly sampled from each category. + mixed_objs = shapenet_dataset.render( + categories=["faucet", "chair"], + sample_nums=[2, 1], + device=device, + cameras=cameras, + raster_settings=raster_settings, + lights=lights, + ) + # Compare the rendered models to the reference images. + for idx in range(3): + mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu() + if DEBUG: + Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR + / ("DEBUG_shapenet_core_render_mixed_by_categories_%s.png" % idx) + ) + image_ref = load_rgb_image( + "test_shapenet_core_render_mixed_by_categories_%s.png" % idx, DATA_DIR + ) + self.assertClose(mixed_rgb, image_ref, atol=0.05) + + # Render a mixture of categories without specifying sample_nums. + mixed_objs_2 = shapenet_dataset.render( + categories=["faucet", "chair"], + device=device, + cameras=cameras, + raster_settings=raster_settings, + lights=lights, + ) + # Compare the rendered models to the reference images. + for idx in range(2): + mixed_rgb_2 = mixed_objs_2[idx, ..., :3].squeeze().cpu() + if DEBUG: + Image.fromarray((mixed_rgb_2.numpy() * 255).astype(np.uint8)).save( + DATA_DIR + / ("DEBUG_shapenet_core_render_without_sample_nums_%s.png" % idx) + ) + image_ref = load_rgb_image( + "test_shapenet_core_render_without_sample_nums_%s.png" % idx, DATA_DIR + ) + self.assertClose(mixed_rgb_2, image_ref, atol=0.05)
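
The following usage sketch (not part of the patch) illustrates the batched render API added above. It assumes SHAPENET_PATH points to a local copy of ShapeNetCore v1 and that a CUDA device is available; the dataset path is a placeholder, the model_id is one of the piano ids used in the tests, and the viewpoint mirrors the test settings.

# Usage sketch only; SHAPENET_PATH and the model_id below are illustrative assumptions.
import torch
from pytorch3d.datasets import ShapeNetCore
from pytorch3d.renderer import OpenGLPerspectiveCameras, look_at_view_transform

SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # hypothetical local dataset path
device = torch.device("cuda:0")
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)

# Shared viewpoint, mirroring the test settings above.
R, T = look_at_view_transform(1.0, 1.0, 90)
cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)

# 1) Render specific models by model_id (id taken from the tests above).
images = shapenet_dataset.render(
    model_ids=["13394ca47c89f91525a3aaf903a41c90"],
    device=device,
    cameras=cameras,
)

# 2) Render two random faucets and one random chair.
images = shapenet_dataset.render(
    categories=["faucet", "chair"],
    sample_nums=[2, 1],
    device=device,
    cameras=cameras,
)

# 3) Render the first three models in the loaded dataset by index.
images = shapenet_dataset.render(idxs=[0, 1, 2], device=device, cameras=cameras)

# The RGB channels of the rendered batch are images[..., :3], as in the tests above.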