From 5636eb6152d4a11593e4268b4d11c9d6d9abdbf1 Mon Sep 17 00:00:00 2001
From: Luya Gao
Date: Tue, 14 Jul 2020 14:52:21 -0700
Subject: [PATCH] Test rendering models for R2N2

Summary: Adding a render function for R2N2.

Reviewed By: nikhilaravi

Differential Revision: D22230228

fbshipit-source-id: a9f588ddcba15bb5d8be1401f68d730e810b4251
---
 pytorch3d/datasets/r2n2/r2n2.py                |   2 +
 pytorch3d/datasets/shapenet_base.py            |   2 +-
 .../data/test_r2n2_render_by_categories_0.png  | Bin 0 -> 2299 bytes
 .../data/test_r2n2_render_by_categories_1.png  | Bin 0 -> 3764 bytes
 .../data/test_r2n2_render_by_categories_2.png  | Bin 0 -> 2390 bytes
 .../test_r2n2_render_by_idxs_and_ids_0.png     | Bin 0 -> 2197 bytes
 .../test_r2n2_render_by_idxs_and_ids_1.png     | Bin 0 -> 2031 bytes
 .../test_r2n2_render_by_idxs_and_ids_2.png     | Bin 0 -> 2065 bytes
 tests/test_r2n2.py                             | 139 +++++++++++++++++-
 9 files changed, 134 insertions(+), 9 deletions(-)
 create mode 100644 tests/data/test_r2n2_render_by_categories_0.png
 create mode 100644 tests/data/test_r2n2_render_by_categories_1.png
 create mode 100644 tests/data/test_r2n2_render_by_categories_2.png
 create mode 100644 tests/data/test_r2n2_render_by_idxs_and_ids_0.png
 create mode 100644 tests/data/test_r2n2_render_by_idxs_and_ids_1.png
 create mode 100644 tests/data/test_r2n2_render_by_idxs_and_ids_2.png

diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py
index 1305e478..56e4208d 100644
--- a/pytorch3d/datasets/r2n2/r2n2.py
+++ b/pytorch3d/datasets/r2n2/r2n2.py
@@ -64,6 +64,7 @@ class R2N2(ShapeNetBase):
                 continue
             synset_set.add(synset)
+            self.synset_starts[synset] = len(self.synset_ids)
             models = split_dict[synset].keys()
             for model in models:
                 # Examine if the given model is present in the ShapeNetCore path.
@@ -78,6 +79,7 @@ class R2N2(ShapeNetBase):
                     continue
                 self.synset_ids.append(synset)
                 self.model_ids.append(model)
+            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
 
         # Examine if all the synsets in the standard R2N2 mapping are present.
         # Update self.synset_inv so that it only includes the loaded categories.
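Note on the bookkeeping above: for every synset kept from the splits file, the loader now
records where that synset's models begin in self.synset_ids (synset_starts) and how many
were loaded (synset_lens). A minimal sketch of how such start/length bookkeeping can be
used to sample dataset indices for one category; sample_idxs_for_synset is a hypothetical
helper shown for illustration only, not part of the library API:

    import random

    def sample_idxs_for_synset(synset_starts, synset_lens, synset, sample_num):
        # Models of one synset occupy a contiguous block of dataset indices:
        # [synset_starts[synset], synset_starts[synset] + synset_lens[synset]).
        start, length = synset_starts[synset], synset_lens[synset]
        return random.sample(range(start, start + length), min(sample_num, length))
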
diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py
index daf156be..d305894f 100644
--- a/pytorch3d/datasets/shapenet_base.py
+++ b/pytorch3d/datasets/shapenet_base.py
@@ -34,7 +34,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
         self.synset_starts = {}
         self.synset_lens = {}
         self.shapenet_dir = ""
-        self.model_dir = ""
+        self.model_dir = "model.obj"
 
     def __len__(self):
         """

[GIT binary patches omitted: the six new reference images listed above
(tests/data/test_r2n2_render_by_categories_{0,1,2}.png and
tests/data/test_r2n2_render_by_idxs_and_ids_{0,1,2}.png) are added as binary files.]

diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py
index 8cd2ed8c..ba5fdf69 100644
--- a/tests/test_r2n2.py
+++ b/tests/test_r2n2.py
@@ -5,10 +5,19 @@ Sanity checks for loading R2N2.
 import json
 import os
 import unittest
+from pathlib import Path
+import numpy as np
 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, load_rgb_image
+from PIL import Image
 from pytorch3d.datasets import R2N2, collate_batched_meshes
+from pytorch3d.renderer import (
+    OpenGLPerspectiveCameras,
+    PointLights,
+    RasterizationSettings,
+    look_at_view_transform,
+)
 from torch.utils.data import DataLoader
 
 
@@ -17,6 +26,9 @@
 R2N2_PATH = None
 SHAPENET_PATH = None
 SPLITS_PATH = None
+DEBUG = False
+DATA_DIR = Path(__file__).resolve().parent / "data"
+
 
 class TestR2N2(TestCaseMixin, unittest.TestCase):
     def setUp(self):
@@ -44,16 +56,14 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
 
     def test_load_R2N2(self):
         """
-        Test loading the train split of R2N2. Check the loaded dataset return items
-        of the correct shapes and types.
+        Test that the loaded train split of R2N2 returns items of the correct shapes and types.
         """
         # Load dataset in the train split.
-        split = "train"
-        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 
         # Check total number of objects in the dataset is correct.
         with open(SPLITS_PATH) as splits:
-            split_dict = json.load(splits)[split]
+            split_dict = json.load(splits)["train"]
         model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
         self.assertEqual(len(r2n2_dataset), sum(model_nums))
 
@@ -75,8 +85,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         the correct shapes and types are returned.
         """
         # Load dataset in the train split.
-        split = "train"
-        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 
         # Randomly retrieve several objects from the dataset and collate them.
         collated_meshes = collate_batched_meshes(
@@ -109,3 +118,117 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         self.assertEqual(len(object_batch["label"]), batch_size)
         self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
         self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
+
+    def test_catch_render_arg_errors(self):
+        """
+        Test rendering R2N2 with an invalid model_id, category or index, and
+        catch the corresponding errors.
+        """
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Try rendering with an invalid model_id and catch the error.
+        with self.assertRaises(ValueError) as err:
+            r2n2_dataset.render(model_ids=["lamp0"])
+        self.assertTrue("not found in the loaded dataset" in str(err.exception))
+
+        # Try rendering with an index out of bounds and catch the error.
+        with self.assertRaises(IndexError) as err:
+            r2n2_dataset.render(idxs=[1000000])
+        self.assertTrue("are out of bounds" in str(err.exception))
+
+    def test_render_r2n2(self):
+        """
+        Test rendering objects from R2N2 selected both by indices and model_ids.
+        """
+        # Set up device and seed for random selections.
+        device = torch.device("cuda:0")
+        torch.manual_seed(39)
+
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Render the first three models in the dataset.
+        R, T = look_at_view_transform(1.0, 1.0, 90)
+        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
+        raster_settings = RasterizationSettings(image_size=512)
+        lights = PointLights(
+            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
+            # TODO: debug the source of the discrepancy in two images when rendering on GPU.
+            diffuse_color=((0, 0, 0),),
+            specular_color=((0, 0, 0),),
+            device=device,
+        )
+
+        r2n2_by_idxs = r2n2_dataset.render(
+            idxs=list(range(3)),
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Check that there are three images in the batch.
+        self.assertEqual(r2n2_by_idxs.shape[0], 3)
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            r2n2_by_idxs_rgb = r2n2_by_idxs[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((r2n2_by_idxs_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_r2n2_render_by_idxs_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_by_idxs_rgb, image_ref, atol=0.05)
+
+        # Render the same models but by model_ids this time.
+        r2n2_by_model_ids = r2n2_dataset.render(
+            model_ids=[
+                "1a4a8592046253ab5ff61a3a2a0e2484",
+                "1a04dcce7027357ab540cc4083acfa57",
+                "1a9d0480b74d782698f5bccb3529a48d",
+            ],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            r2n2_by_model_ids_rgb = r2n2_by_model_ids[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray(
+                    (r2n2_by_model_ids_rgb.numpy() * 255).astype(np.uint8)
+                ).save(DATA_DIR / ("DEBUG_r2n2_render_by_model_ids_%s.png" % idx))
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_by_model_ids_rgb, image_ref, atol=0.05)
+
+        ###############################
+        # Test rendering by categories
+        ###############################
+
+        # Render a mixture of categories.
+        categories = ["chair", "lamp"]
+        mixed_objs = r2n2_dataset.render(
+            categories=categories,
+            sample_nums=[1, 2],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_r2n2_render_by_categories_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(mixed_rgb, image_ref, atol=0.05)
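
Usage sketch based on the tests above (the three dataset paths are placeholders and must
point at local copies of ShapeNetCore v1, the R2N2 renderings, and the R2N2 splits file;
camera, light, and rasterization settings mirror test_render_r2n2):

    import torch
    from pytorch3d.datasets import R2N2
    from pytorch3d.renderer import (
        OpenGLPerspectiveCameras,
        PointLights,
        RasterizationSettings,
        look_at_view_transform,
    )

    SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # placeholder
    R2N2_PATH = "/path/to/r2n2"  # placeholder
    SPLITS_PATH = "/path/to/r2n2_splits.json"  # placeholder

    device = torch.device("cuda:0")
    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

    # Camera, rasterization and light settings as in the tests.
    R, T = look_at_view_transform(1.0, 1.0, 90)
    cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
    raster_settings = RasterizationSettings(image_size=512)
    lights = PointLights(location=torch.tensor([[0.0, 1.0, -2.0]], device=device), device=device)

    # Render by dataset index; the tests also select by model_ids or by categories/sample_nums.
    images = r2n2_dataset.render(
        idxs=[0, 1, 2],
        device=device,
        cameras=cameras,
        raster_settings=raster_settings,
        lights=lights,
    )
    rgb = images[..., :3]  # the tests compare these RGB channels against the reference PNGs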