From 326e4ccb5bc3de3bdee97cc5907e2e2c75bdcc73 Mon Sep 17 00:00:00 2001
From: Luya Gao
Date: Fri, 7 Aug 2020 13:21:26 -0700
Subject: [PATCH] Return R2N2 R,T,K

Summary: Return rotation, translation and intrinsic matrices necessary to
reproduce R2N2's own renderings.
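A brief usage sketch of the new fields (illustration only, not part of this
patch; the dataset paths, model index and view indices below are placeholders):

    from pytorch3d.datasets import R2N2
    from pytorch3d.datasets.r2n2.r2n2 import BlenderCamera

    SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # placeholder
    R2N2_PATH = "/path/to/ShapeNetRendering"  # placeholder
    SPLITS_PATH = "/path/to/r2n2_splits.json"  # placeholder

    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

    # Indexing with (model_idx, view_idxs) returns only the requested views.
    model = r2n2_dataset[6, [1, 3]]
    R, T, K = model["R"], model["T"], model["K"]  # (V, 3, 3), (V, 3), (V, 4, 4)

    # The calibrations can drive a BlenderCamera (defined below in r2n2.py) to
    # reproduce the orientations of R2N2's own renderings, or the new
    # R2N2.render can be called directly and will build the cameras itself.
    cameras = BlenderCamera(R=R, T=T, K=K)
    images = r2n2_dataset.render(idxs=[6], view_idxs=[1, 3], device="cpu")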
Reviewed By: nikhilaravi

Differential Revision: D22462520

fbshipit-source-id: 46a3859743ebc43c7a24f75827d2be3adf3f486b
---
 pytorch3d/datasets/r2n2/r2n2.py               | 126 +++++++++++++++++-
 pytorch3d/datasets/shapenet_base.py           |  29 ++--
 pytorch3d/datasets/utils.py                   |  77 ++++++++++-
 ...2n2_render_with_blender_calibrations_0.png | Bin 0 -> 3119 bytes
 ...2n2_render_with_blender_calibrations_1.png | Bin 0 -> 2792 bytes
 ...2n2_render_with_blender_calibrations_2.png | Bin 0 -> 3000 bytes
 ...2n2_render_with_blender_calibrations_3.png | Bin 0 -> 3203 bytes
 tests/test_r2n2.py                            |  65 ++++++++-
 8 files changed, 277 insertions(+), 20 deletions(-)
 create mode 100644 tests/data/test_r2n2_render_with_blender_calibrations_0.png
 create mode 100644 tests/data/test_r2n2_render_with_blender_calibrations_1.png
 create mode 100644 tests/data/test_r2n2_render_with_blender_calibrations_2.png
 create mode 100644 tests/data/test_r2n2_render_with_blender_calibrations_3.png

diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py
index 13214c39..ecf3fe62 100644
--- a/pytorch3d/datasets/r2n2/r2n2.py
+++ b/pytorch3d/datasets/r2n2/r2n2.py
@@ -10,7 +10,9 @@ import numpy as np
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
+from pytorch3d.datasets.utils import compute_extrinsic_matrix
 from pytorch3d.io import load_obj
+from pytorch3d.renderer import HardPhongShader
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.transforms import Transform3d
 from tabulate import tabulate
@@ -168,6 +170,9 @@ class R2N2(ShapeNetBase):
             - label (str): synset label.
             - images: FloatTensor of shape (V, H, W, C), where V is number of views
                 returned. Returns a batch of the renderings of the models from the R2N2 dataset.
+            - R: Rotation matrix of shape (V, 3, 3), where V is the number of views returned.
+            - T: Translation matrix of shape (V, 3), where V is the number of views returned.
+            - K: Intrinsic matrix of shape (V, 4, 4), where V is the number of views returned.
         """
         if isinstance(model_idx, tuple):
             model_idx, view_idxs = model_idx
@@ -213,7 +218,11 @@ class R2N2(ShapeNetBase):
             "rendering",
         )
 
-        images = []
+        # Read the metadata file to obtain params for the calibration matrices.
+        with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
+            metadata_lines = f.readlines()
+
+        images, Rs, Ts = [], [], []
         for i in model_views:
             # Read image.
             image_path = path.join(rendering_path, "%02d.png" % i)
@@ -221,10 +230,125 @@ class R2N2(ShapeNetBase):
             image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
             images.append(image.to(dtype=torch.float32))
 
+            # Get the camera calibration for this view.
+            azim, elev, yaw, dist_ratio, fov = [
+                float(v) for v in metadata_lines[i].strip().split(" ")
+            ]
+            R, T = self._compute_camera_calibration(azim, elev, dist_ratio)
+            Rs.append(R)
+            Ts.append(T)
+
+        # Intrinsic matrix extracted from Blender with a slight modification to work
+        # with the PyTorch3D world space. Taken from the meshrcnn codebase:
+        # https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
+        K = torch.tensor(
+            [
+                [2.1875, 0.0, 0.0, 0.0],
+                [0.0, 2.1875, 0.0, 0.0],
+                [0.0, 0.0, -1.002002, -0.2002002],
+                [0.0, 0.0, 1.0, 0.0],
+            ]
+        )
         model["images"] = torch.stack(images)
+        model["R"] = torch.stack(Rs)
+        model["T"] = torch.stack(Ts)
+        model["K"] = K.expand(len(model_views), 4, 4)
 
         return model
 
+    def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
+        """
+        Helper function for calculating rotation and translation matrices from azimuth
+        angle, elevation and distance ratio.
+
+        Args:
+            azim: Rotation about the z-axis, in degrees.
+            elev: Rotation above the xy-plane, in degrees.
+            dist_ratio: Ratio of distance from the origin to the maximum camera distance.
+
+        Returns:
+            - R: Rotation matrix of shape (3, 3).
+            - T: Translation matrix of shape (3).
+        """
+        # Retrieve R, T of the selected view(s) by reading the metadata.
+        MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
+        dist = dist_ratio * MAX_CAMERA_DISTANCE
+        RT = compute_extrinsic_matrix(azim, elev, dist)  # (4, 4)
+
+        # Transform the mesh vertices from the ShapeNet world to the PyTorch3D world.
+        shapenet_to_pytorch3d = torch.tensor(
+            [
+                [-1.0, 0.0, 0.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, -1.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0],
+            ],
+            dtype=torch.float32,
+        )
+        RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
+
+        # Extract the rotation and translation matrices from RT.
+        R = RT[:3, :3]
+        T = RT[3, :3]
+        return R, T
+
+    def render(
+        self,
+        model_ids: Optional[List[str]] = None,
+        categories: Optional[List[str]] = None,
+        sample_nums: Optional[List[int]] = None,
+        idxs: Optional[List[int]] = None,
+        view_idxs: Optional[List[int]] = None,
+        shader_type=HardPhongShader,
+        device="cpu",
+        **kwargs
+    ) -> torch.Tensor:
+        """
+        Render models with a BlenderCamera by default to achieve the same orientations
+        as the R2N2 renderings. Also accepts other types of cameras and any of the args
+        that the render function in the ShapeNetBase class accepts.
+
+        Args:
+            view_idxs: each model will be rendered with the orientation(s) of the
+                specified views. Only renders by view_idxs if no camera or args for
+                BlenderCamera are supplied.
+            Accepts any of the args of the render function in ShapeNetBase:
+            model_ids: List[str] of model_ids of models intended to be rendered.
+            categories: List[str] of categories intended to be rendered. categories
+                and sample_nums must be specified at the same time. categories can be
+                given in the form of synset offsets or labels, or a combination of both.
+            sample_nums: List[int] of number of models to be randomly sampled from
+                each category. Could also contain one single integer, in which case it
+                will be broadcast for every category.
+            idxs: List[int] of indices of models to be rendered in the dataset.
+            shader_type: Shader to use for rendering. Examples include HardPhongShader
+                (default), SoftPhongShader etc., or any other type of valid Shader class.
+            device: torch.device on which the tensors should be located.
+            **kwargs: Accepts any of the kwargs that the renderer supports and any of
+                the args that BlenderCamera supports.
+
+        Returns:
+            Batch of rendered images of shape (N, H, W, 3).
+        """
+        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        r = torch.cat([self[idxs[i], view_idxs]["R"] for i in range(len(idxs))])
+        t = torch.cat([self[idxs[i], view_idxs]["T"] for i in range(len(idxs))])
+        k = torch.cat([self[idxs[i], view_idxs]["K"] for i in range(len(idxs))])
+        # Initialize the default cameras using R, T, K from kwargs, or the R, T, K of
+        # the specified views.
+        blend_cameras = BlenderCamera(
+            R=kwargs.get("R", r),
+            T=kwargs.get("T", t),
+            K=kwargs.get("K", k),
+            device=device,
+        )
+        cameras = kwargs.get("cameras", blend_cameras).to(device)
+        kwargs.pop("cameras", None)
+        # Pass down all the same inputs.
+        return super().render(
+            idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
+        )
+
 
 class BlenderCamera(CamerasBase):
     """
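A note on the R/T split in _compute_camera_calibration above: PyTorch3D cameras
use a row-vector convention (a world point p maps to camera space as p @ R + T),
while compute_extrinsic_matrix returns a column-vector world-to-camera matrix,
hence the transpose before slicing. A minimal sketch of that split (the identity
RT below is a stand-in, not a value from the dataset):

    import torch

    RT = torch.eye(4)  # stand-in for compute_extrinsic_matrix(azim, elev, dist)
    shapenet_to_pytorch3d = torch.tensor(
        [
            [-1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    RT = RT.t().mm(shapenet_to_pytorch3d)
    R, T = RT[:3, :3], RT[3, :3]  # rotation in the top-left block, translation in the last row
    p_cam = torch.zeros(1, 3) @ R + T  # row-vector transform of the world origin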
diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py
index 722bde09..735a8414 100644
--- a/pytorch3d/datasets/shapenet_base.py
+++ b/pytorch3d/datasets/shapenet_base.py
@@ -111,12 +111,27 @@ class ShapeNetBase(torch.utils.data.Dataset):
         Returns:
             Batch of rendered images of shape (N, H, W, 3).
         """
-        paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        paths = [
+            path.join(
+                self.shapenet_dir,
+                self.synset_ids[idx],
+                self.model_ids[idx],
+                self.model_dir,
+            )
+            for idx in idxs
+        ]
         meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
         meshes.textures = TexturesVertex(
             verts_features=torch.ones_like(meshes.verts_padded(), device=device)
         )
         cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
+        if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
+            raise ValueError("Mismatch between batch dims of cameras and meshes.")
+        if len(cameras) > 1:
+            # When rendering R2N2 models, if more than one view is provided, broadcast
+            # the meshes so that each mesh can be rendered for each of the views.
+            meshes = meshes.extend(len(cameras) // len(meshes))
         renderer = MeshRenderer(
             rasterizer=MeshRasterizer(
                 cameras=cameras,
@@ -136,7 +151,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
         categories: Optional[List[str]] = None,
         sample_nums: Optional[List[int]] = None,
         idxs: Optional[List[int]] = None,
-    ) -> List[str]:
+    ) -> List[int]:
         """
         Helper function for converting user provided model_ids, categories and sample_nums
         to indices of models in the loaded dataset. If model idxs are provided, we check if
@@ -206,15 +221,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
             )
             warnings.warn(msg)
             idxs = self._sample_idxs_from_category(sample_nums[0])
-        return [
-            path.join(
-                self.shapenet_dir,
-                self.synset_ids[idx],
-                self.model_ids[idx],
-                self.model_dir,
-            )
-            for idx in idxs
-        ]
+        return idxs
 
     def _sample_idxs_from_category(
         self, sample_num: int = 1, category: Optional[str] = None
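The broadcast added above relies on Meshes.extend repeating each mesh
consecutively, which lines up with the per-model ordering of the concatenated
view cameras in R2N2.render. A toy sketch (made-up geometry, not from the
dataset):

    import torch
    from pytorch3d.structures import Meshes

    verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    faces = torch.tensor([[0, 1, 2]])
    meshes = Meshes(verts=[verts, verts + 1.0], faces=[faces, faces])  # 2 models
    num_views = 3  # e.g. 6 cameras for 2 models
    extended = meshes.extend(num_views)  # order: m0, m0, m0, m1, m1, m1
    print(len(extended))  # 6, one mesh per camera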
diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py
index 5c2f4bc0..43243f5c 100644
--- a/pytorch3d/datasets/utils.py
+++ b/pytorch3d/datasets/utils.py
@@ -1,5 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
+import math
 from typing import Dict, List
 
 import torch
@@ -34,10 +34,77 @@ def collate_batched_meshes(batch: List[Dict]):
         verts=collated_dict["verts"], faces=collated_dict["faces"]
     )
 
-    # If collate_batched_meshes receives R2N2 items, stack the batches of
-    # views of each model into a new batch of shape (N, V, H, W, 3) where
-    # V is the number of views.
+    # If collate_batched_meshes receives R2N2 items with images and
+    # all models have the same number of views V, stack the batches of
+    # views of each model into a new batch of shape (N, V, H, W, 3).
+    # Otherwise leave it as a list.
     if "images" in collated_dict:
-        collated_dict["images"] = torch.stack(collated_dict["images"])
+        try:
+            collated_dict["images"] = torch.stack(collated_dict["images"])
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of images instead of batches."
+            )
+
+    # If collate_batched_meshes receives R2N2 items with camera calibration
+    # matrices and all models have the same number of views V, stack each
+    # type of matrix into a new batch of shape (N, V, ...).
+    # Otherwise leave them as lists.
+    if all(x in collated_dict for x in ["R", "T", "K"]):
+        try:
+            collated_dict["R"] = torch.stack(collated_dict["R"])  # (N, V, 3, 3)
+            collated_dict["T"] = torch.stack(collated_dict["T"])  # (N, V, 3)
+            collated_dict["K"] = torch.stack(collated_dict["K"])  # (N, V, 4, 4)
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of calibration matrices instead of batches."
+            )
 
     return collated_dict
+
+
+def compute_extrinsic_matrix(azimuth, elevation, distance):
+    """
+    Copied from the meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96
+
+    Compute a 4x4 extrinsic matrix that converts from homogeneous world coordinates
+    to homogeneous camera coordinates. We assume that the camera is looking at the
+    origin.
+    Used in the R2N2 dataset when computing calibration matrices.
+
+    Args:
+        azimuth: Rotation about the z-axis, in degrees.
+        elevation: Rotation above the xy-plane, in degrees.
+        distance: Distance from the origin.
+
+    Returns:
+        FloatTensor of shape (4, 4).
+    """
+    azimuth, elevation, distance = float(azimuth), float(elevation), float(distance)
+
+    az_rad = -math.pi * azimuth / 180.0
+    el_rad = -math.pi * elevation / 180.0
+    sa = math.sin(az_rad)
+    ca = math.cos(az_rad)
+    se = math.sin(el_rad)
+    ce = math.cos(el_rad)
+    R_world2obj = torch.tensor(
+        [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]]
+    )
+    R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+    R_world2cam = R_obj2cam.mm(R_world2obj)
+    cam_location = torch.tensor([[distance, 0, 0]]).t()
+    T_world2cam = -(R_obj2cam.mm(cam_location))
+    RT = torch.cat([R_world2cam, T_world2cam], dim=1)
+    RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])])
+
+    # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it
+    # rotates the model 90 degrees about the x axis. To compensate for this quirk,
+    # we roll that rotation into the extrinsic matrix here.
+    rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+    RT = RT.mm(rot.to(RT))
+
+    return RT
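A quick sanity check of the conventions above (illustration only; it assumes
this patch is applied): with zero azimuth and elevation the camera sits on a
world axis at the requested distance, so the homogeneous world origin should
map to (0, 0, -distance) in camera coordinates, i.e. straight ahead of an
OpenGL-style camera looking down its negative z-axis; the shapenet_to_pytorch3d
flip in _compute_camera_calibration then converts this to PyTorch3D's
convention.

    import torch
    from pytorch3d.datasets.utils import compute_extrinsic_matrix

    distance = 1.75  # MAX_CAMERA_DISTANCE used by R2N2
    RT = compute_extrinsic_matrix(0.0, 0.0, distance)
    origin = torch.tensor([0.0, 0.0, 0.0, 1.0])
    print(RT @ origin)  # expected: tensor([ 0.0000,  0.0000, -1.7500,  1.0000])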
diff --git a/tests/data/test_r2n2_render_with_blender_calibrations_0.png b/tests/data/test_r2n2_render_with_blender_calibrations_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9c169c99b1313f3dcc29cccffdf565b2e0cc3b8
GIT binary patch
literal 3119
[binary PNG data omitted]
diff --git a/tests/data/test_r2n2_render_with_blender_calibrations_1.png b/tests/data/test_r2n2_render_with_blender_calibrations_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..7c86933805f4485224516a27353792ccf807e9cf
GIT binary patch
literal 2792
[binary PNG data omitted]

diff --git a/tests/data/test_r2n2_render_with_blender_calibrations_2.png b/tests/data/test_r2n2_render_with_blender_calibrations_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..1cbda3d4c45a3559df87eae90e83b0552952392c
GIT binary patch
literal 3000
[binary PNG data omitted]