From 63ba74f1a8d5aa73d36999f5b5a7bf2af0fd8066 Mon Sep 17 00:00:00 2001 From: Luya Gao Date: Fri, 7 Aug 2020 13:21:26 -0700 Subject: [PATCH] Return R2N2 voxel coordinates Summary: Return R2N2's voxel coordinates. Reviewed By: nikhilaravi Differential Revision: D22462530 fbshipit-source-id: a995cfa0957b2561eb3b0f4591cb1db42170bc68 --- pytorch3d/datasets/__init__.py | 2 +- pytorch3d/datasets/r2n2/__init__.py | 3 +- pytorch3d/datasets/r2n2/r2n2.py | 126 +++-- pytorch3d/datasets/r2n2/utils.py | 483 ++++++++++++++++++ pytorch3d/datasets/utils.py | 76 --- tests/data/test_r2n2_voxel_to_mesh_render.png | Bin 0 -> 16918 bytes tests/test_r2n2.py | 47 +- 7 files changed, 602 insertions(+), 135 deletions(-) create mode 100644 pytorch3d/datasets/r2n2/utils.py create mode 100644 tests/data/test_r2n2_voxel_to_mesh_render.png diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py index 16872130..f35ab33c 100644 --- a/pytorch3d/datasets/__init__.py +++ b/pytorch3d/datasets/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .r2n2 import R2N2, BlenderCamera +from .r2n2 import R2N2, BlenderCamera, collate_batched_R2N2, render_cubified_voxels from .shapenet import ShapeNetCore from .utils import collate_batched_meshes diff --git a/pytorch3d/datasets/r2n2/__init__.py b/pytorch3d/datasets/r2n2/__init__.py index f18dc45d..b40be2c1 100644 --- a/pytorch3d/datasets/r2n2/__init__.py +++ b/pytorch3d/datasets/r2n2/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .r2n2 import R2N2, BlenderCamera +from .r2n2 import R2N2 +from .utils import BlenderCamera, collate_batched_R2N2, render_cubified_voxels __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py index ecf3fe62..3abe4927 100644 --- a/pytorch3d/datasets/r2n2/r2n2.py +++ b/pytorch3d/datasets/r2n2/r2n2.py @@ -10,20 +10,32 @@ import numpy as np import torch from PIL import Image from pytorch3d.datasets.shapenet_base import ShapeNetBase -from pytorch3d.datasets.utils import compute_extrinsic_matrix from pytorch3d.io import load_obj from pytorch3d.renderer import HardPhongShader -from pytorch3d.renderer.cameras import CamerasBase -from pytorch3d.transforms import Transform3d from tabulate import tabulate +from .utils import ( + BlenderCamera, + align_bbox, + compute_extrinsic_matrix, + read_binvox_coords, + voxelize, +) + SYNSET_DICT_DIR = Path(__file__).resolve().parent - -# Default values of rotation, translation and intrinsic matrices for BlenderCamera. -r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3) -t = np.expand_dims(np.zeros(3), axis=0) # (1, 3) -k = np.expand_dims(np.eye(4), axis=0) # (1, 4, 4) +MAX_CAMERA_DISTANCE = 1.75 # Constant from R2N2. +VOXEL_SIZE = 128 +# Intrinsic matrix extracted from Blender. Taken from meshrcnn codebase: +# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py +BLENDER_INTRINSIC = torch.tensor( + [ + [2.1875, 0.0, 0.0, 0.0], + [0.0, 2.1875, 0.0, 0.0], + [0.0, 0.0, -1.002002, -0.2002002], + [0.0, 0.0, -1.0, 0.0], + ] +) class R2N2(ShapeNetBase): @@ -42,6 +54,7 @@ class R2N2(ShapeNetBase): r2n2_dir, splits_file, return_all_views: bool = True, + return_voxels: bool = False, ): """ Store each object's synset id and models id the given directories. @@ -54,6 +67,8 @@ class R2N2(ShapeNetBase): return_all_views (bool): Indicator of whether or not to load all the views in the split. 
If set to False, one of the views in the split will be randomly selected and loaded. + return_voxels(bool): Indicator of whether or not to return voxels as a tensor + of shape (D, D, D) where D is the number of voxels along each dimension. """ super().__init__() self.shapenet_dir = shapenet_dir @@ -83,6 +98,16 @@ class R2N2(ShapeNetBase): ) % (r2n2_dir) warnings.warn(msg) + self.return_voxels = return_voxels + # Check if the folder containing voxel coordinates is included in r2n2_dir. + if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")): + self.return_voxels = False + msg = ( + "ShapeNetVox32 not found in %s. Voxel coordinates will " + "be skipped when returning models." + ) % (r2n2_dir) + warnings.warn(msg) + synset_set = set() # Store lists of views of each model in a list. self.views_per_model_list = [] @@ -173,6 +198,8 @@ class R2N2(ShapeNetBase): - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned. - T: Translation matrix of shape (V, 3), where V is number of views returned. - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned. + - voxels: Voxels of shape (D, D, D), where D is the number of voxels along each + dimension. """ if isinstance(model_idx, tuple): model_idx, view_idxs = model_idx @@ -208,6 +235,7 @@ class R2N2(ShapeNetBase): model["label"] = self.synset_dict[model["synset_id"]] model["images"] = None + images, Rs, Ts, voxel_RTs = [], [], [], [] # Retrieve R2N2's renderings if required. if self.return_images: rendering_path = path.join( @@ -217,12 +245,9 @@ class R2N2(ShapeNetBase): model["model_id"], "rendering", ) - # Read metadata file to obtain params for calibration matrices. with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f: metadata_lines = f.readlines() - - images, Rs, Ts = [], [], [] for i in model_views: # Read image. image_path = path.join(rendering_path, "%02d.png" % i) @@ -234,9 +259,13 @@ class R2N2(ShapeNetBase): azim, elev, yaw, dist_ratio, fov = [ float(v) for v in metadata_lines[i].strip().split(" ") ] - R, T = self._compute_camera_calibration(azim, elev, dist_ratio) + dist = dist_ratio * MAX_CAMERA_DISTANCE + # Extrinsic matrix before transformation to PyTorch3D world space. + RT = compute_extrinsic_matrix(azim, elev, dist) + R, T = self._compute_camera_calibration(RT) Rs.append(R) Ts.append(T) + voxel_RTs.append(RT) # Intrinsic matrix extracted from the Blender with slight modification to work with # PyTorch3D world space. Taken from meshrcnn codebase: @@ -254,27 +283,48 @@ class R2N2(ShapeNetBase): model["T"] = torch.stack(Ts) model["K"] = K.expand(len(model_views), 4, 4) + voxels_list = [] + # Read voxels if required. + voxel_path = path.join( + self.r2n2_dir, + "ShapeNetVox32", + model["synset_id"], + model["model_id"], + "model.binvox", + ) + if self.return_voxels: + if not path.isfile(voxel_path): + msg = "Voxel file not found for model %s from category %s." + raise FileNotFoundError(msg % (model["model_id"], model["synset_id"])) + + with open(voxel_path, "rb") as f: + # Read voxel coordinates as a tensor of shape (N, 3). + voxel_coords = read_binvox_coords(f) + # Align voxels to the same coordinate system as mesh verts. + voxel_coords = align_bbox(voxel_coords, model["verts"]) + for RT in voxel_RTs: + # Compute projection matrix. + P = BLENDER_INTRINSIC.mm(RT) + # Convert voxel coordinates of shape (N, 3) to voxels of shape (D, D, D). 
+ voxels = voxelize(voxel_coords, P, VOXEL_SIZE) + voxels_list.append(voxels) + model["voxels"] = torch.stack(voxels_list) + return model - def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float): + def _compute_camera_calibration(self, RT): """ - Helper function for calculating rotation and translation matrices from azimuth - angle, elevation and distance ratio. + Helper function for calculating rotation and translation matrices from ShapeNet + to camera transformation and ShapeNet to PyTorch3D transformation. Args: - azim: Rotation about the z-axis, in degrees. - elev: Rotation above the xy-plane, in degrees. - dist_ratio: Ratio of distance from the origin to the maximum camera distance. + RT: Extrinsic matrix that performs ShapeNet world view to camera view + transformation. Returns: - - R: Rotation matrix of shape (3, 3). - - T: Translation matrix of shape (3). + R: Rotation matrix of shape (3, 3). + T: Translation matrix of shape (3). """ - # Retrive R,T,K of the selected view(s) by reading the metadata. - MAX_CAMERA_DISTANCE = 1.75 # Constant from R2N2. - dist = dist_ratio * MAX_CAMERA_DISTANCE - RT = compute_extrinsic_matrix(azim, elev, dist) - # Transform the mesh vertices from shapenet world to pytorch3d world. shapenet_to_pytorch3d = torch.tensor( [ @@ -285,9 +335,7 @@ class R2N2(ShapeNetBase): ], dtype=torch.float32, ) - RT = compute_extrinsic_matrix(azim, elev, dist) # (4, 4) RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d) # (4, 4) - # Extract rotation and translation matrices from RT. R = RT[:3, :3] T = RT[3, :3] @@ -348,27 +396,3 @@ class R2N2(ShapeNetBase): return super().render( idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs ) - - -class BlenderCamera(CamerasBase): - """ - Camera for rendering objects with calibration matrices from the R2N2 dataset - (which uses Blender for rendering the views for each model). - """ - - def __init__(self, R=r, T=t, K=k, device="cpu"): - """ - Args: - R: Rotation matrix of shape (N, 3, 3). - T: Translation matrix of shape (N, 3). - K: Intrinsic matrix of shape (N, 4, 4). - device: torch.device or str. - """ - # The initializer formats all inputs to torch tensors and broadcasts - # all the inputs to have the same batch dimension where necessary. - super().__init__(device=device, R=R, T=T, K=K) - - def get_projection_transform(self, **kwargs) -> Transform3d: - transform = Transform3d(device=self.device) - transform._matrix = self.K.transpose(1, 2).contiguous() # pyre-ignore[16] - return transform diff --git a/pytorch3d/datasets/r2n2/utils.py b/pytorch3d/datasets/r2n2/utils.py new file mode 100644 index 00000000..f82636b1 --- /dev/null +++ b/pytorch3d/datasets/r2n2/utils.py @@ -0,0 +1,483 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import math +from typing import Dict, List + +import numpy as np +import torch +from pytorch3d.datasets.utils import collate_batched_meshes +from pytorch3d.ops import cubify +from pytorch3d.renderer import ( + HardPhongShader, + MeshRasterizer, + MeshRenderer, + PointLights, + RasterizationSettings, + TexturesVertex, +) +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.transforms import Transform3d + + +# Empirical min and max over the dataset from meshrcnn. 
+# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L9 +SHAPENET_MIN_ZMIN = 0.67 +SHAPENET_MAX_ZMAX = 0.92 +# Threshold for cubify from meshrcnn: +# https://github.com/facebookresearch/meshrcnn/blob/master/configs/shapenet/voxmesh_R50.yaml#L11 +CUBIFY_THRESH = 0.2 + +# Default values of rotation, translation and intrinsic matrices for BlenderCamera. +r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3) +t = np.expand_dims(np.zeros(3), axis=0) # (1, 3) +k = np.expand_dims(np.eye(4), axis=0) # (1, 4, 4) + + +def collate_batched_R2N2(batch: List[Dict]): + """ + Take a list of objects in the form of dictionaries and merge them + into a single dictionary. This function can be used with a Dataset + object to create a torch.utils.data.Dataloader which directly + returns Meshes objects. + TODO: Add support for textures. + + Args: + batch: List of dictionaries containing information about objects + in the dataset. + + Returns: + collated_dict: Dictionary of collated lists. If batch contains both + verts and faces, a collated mesh batch is also returned. + """ + collated_dict = collate_batched_meshes(batch) + + # If collate_batched_meshes receives R2N2 items with images and that + # all models have the same number of views V, stack the batches of + # views of each model into a new batch of shape (N, V, H, W, 3). + # Otherwise leave it as a list. + if "images" in collated_dict: + try: + collated_dict["images"] = torch.stack(collated_dict["images"]) + except RuntimeError: + print( + "Models don't have the same number of views. Now returning " + "lists of images instead of batches." + ) + + # If collate_batched_meshes receives R2N2 items with camera calibration + # matrices and that all models have the same number of views V, stack each + # type of matrices into a new batch of shape (N, V, ...). + # Otherwise leave them as lists. + if all(x in collated_dict for x in ["R", "T", "K"]): + try: + collated_dict["R"] = torch.stack(collated_dict["R"]) # (N, V, 3, 3) + collated_dict["T"] = torch.stack(collated_dict["T"]) # (N, V, 3) + collated_dict["K"] = torch.stack(collated_dict["K"]) # (N, V, 4, 4) + except RuntimeError: + print( + "Models don't have the same number of views. Now returning " + "lists of calibration matrices instead of a batched tensor." + ) + + # If collate_batched_meshes receives voxels and all models have the same + # number of views V, stack the batches of voxels into a new batch of shape + # (N, V, S, S, S), where S is the voxel size. + if "voxels" in collated_dict: + try: + collated_dict["voxels"] = torch.stack(collated_dict["voxels"]) + except RuntimeError: + print( + "Models don't have the same number of views. Now returning " + "lists of voxels instead of a batched tensor." + ) + return collated_dict + + +def compute_extrinsic_matrix(azimuth, elevation, distance): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96 + + Compute 4x4 extrinsic matrix that converts from homogenous world coordinates + to homogenous camera coordinates. We assume that the camera is looking at the + origin. + Used in R2N2 Dataset when computing calibration matrices. + + Args: + azimuth: Rotation about the z-axis, in degrees. + elevation: Rotation above the xy-plane, in degrees. + distance: Distance from the origin. + + Returns: + FloatTensor of shape (4, 4). 
+ """ + azimuth, elevation, distance = float(azimuth), float(elevation), float(distance) + + az_rad = -math.pi * azimuth / 180.0 + el_rad = -math.pi * elevation / 180.0 + sa = math.sin(az_rad) + ca = math.cos(az_rad) + se = math.sin(el_rad) + ce = math.cos(el_rad) + R_world2obj = torch.tensor( + [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]] + ) + R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + R_world2cam = R_obj2cam.mm(R_world2obj) + cam_location = torch.tensor([[distance, 0, 0]]).t() + T_world2cam = -(R_obj2cam.mm(cam_location)) + RT = torch.cat([R_world2cam, T_world2cam], dim=1) + RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])]) + + # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it + # rotates the model 90 degrees about the x axis. To compensate for this quirk we + # roll that rotation into the extrinsic matrix here + rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) + RT = RT.mm(rot.to(RT)) + + return RT + + +def read_binvox_coords(f, integer_division=True, dtype=torch.float32): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L5 + + Read a binvox file and return the indices of all nonzero voxels. + + This matches the behavior of binvox_rw.read_as_coord_array + (https://github.com/dimatura/binvox-rw-py/blob/public/binvox_rw.py#L153) + but this implementation uses torch rather than numpy, and is more efficient + due to improved vectorization. + + Georgia: I think that binvox_rw.read_as_coord_array actually has a bug; when converting + linear indices into three-dimensional indices, they use floating-point + division instead of integer division. We can reproduce their incorrect + implementation by passing integer_division=False. + + Args: + f (str): A file pointer to the binvox file to read + integer_division (bool): If False, then match the buggy implementation from binvox_rw + dtype: Datatype of the output tensor. Use float64 to match binvox_rw + + Returns: + coords (tensor): A tensor of shape (N, 3) where N is the number of nonzero voxels, + and coords[i] = (x, y, z) gives the index of the ith nonzero voxel. If the + voxel grid has shape (V, V, V) then we have 0 <= x, y, z < V. + """ + size, translation, scale = _read_binvox_header(f) + storage = torch.ByteStorage.from_buffer(f.read()) + data = torch.tensor([], dtype=torch.uint8) + data.set_(source=storage) + vals, counts = data[::2], data[1::2] + idxs = _compute_idxs(vals, counts) + if not integer_division: + idxs = idxs.to(dtype) + x_idxs = idxs // (size * size) + zy_idxs = idxs % (size * size) + z_idxs = zy_idxs // size + y_idxs = zy_idxs % size + coords = torch.stack([x_idxs, y_idxs, z_idxs], dim=1) + return coords.to(dtype) + + +def _compute_idxs(vals, counts): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L58 + + Fast vectorized version of index computation. + + Args: + vals: tensor of binary values indicating voxel presence in a dense format. + counts: tensor of number of occurence of each value in vals. + + Returns: + idxs: A tensor of shape (N), where N is the number of nonzero voxels. 
+ """ + # Consider an example where: + # vals = [0, 1, 0, 1, 1] + # counts = [2, 3, 3, 2, 1] + # + # These values of counts and vals mean that the dense binary grid is: + # [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1] + # + # So the nonzero indices we want to return are: + # [2, 3, 4, 8, 9, 10] + + # After the cumsum we will have: + # end_idxs = [2, 5, 8, 10, 11] + end_idxs = counts.cumsum(dim=0) + + # After masking and computing start_idx we have: + # end_idxs = [5, 10, 11] + # counts = [3, 2, 1] + # start_idxs = [2, 8, 10] + mask = vals == 1 + end_idxs = end_idxs[mask] + counts = counts[mask].to(end_idxs) + start_idxs = end_idxs - counts + + # We initialize delta as: + # [2, 1, 1, 1, 1, 1] + delta = torch.ones(counts.sum().item(), dtype=torch.int64) + delta[0] = start_idxs[0] + + # We compute pos = [3, 5], val = [3, 0]; then delta is + # [2, 1, 1, 4, 1, 1] + pos = counts.cumsum(dim=0)[:-1] + val = start_idxs[1:] - end_idxs[:-1] + delta[pos] += val + + # A final cumsum gives the idx we want: [2, 3, 4, 8, 9, 10] + idxs = delta.cumsum(dim=0) + return idxs + + +def _read_binvox_header(f): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L99 + + Read binvox header and extract information regarding voxel sizes and translations + to original voxel coordinates. + + Args: + f (str): A file pointer to the binvox file to read. + + Returns: + size (int): size of voxel. + translation (tuple(float)): translation to original voxel coordinates. + scale (float): scale to original voxel coordinates. + """ + # First line of the header should be "#binvox 1" + line = f.readline().strip() + if line != b"#binvox 1": + raise ValueError("Invalid header (line 1)") + + # Second line of the header should be "dim [int] [int] [int]" + # and all three int should be the same + line = f.readline().strip() + if not line.startswith(b"dim "): + raise ValueError("Invalid header (line 2)") + dims = line.split(b" ") + try: + dims = [int(d) for d in dims[1:]] + except ValueError: + raise ValueError("Invalid header (line 2)") + if len(dims) != 3 or dims[0] != dims[1] or dims[0] != dims[2]: + raise ValueError("Invalid header (line 2)") + size = dims[0] + + # Third line of the header should be "translate [float] [float] [float]" + line = f.readline().strip() + if not line.startswith(b"translate "): + raise ValueError("Invalid header (line 3)") + translation = line.split(b" ") + if len(translation) != 4: + raise ValueError("Invalid header (line 3)") + try: + translation = tuple(float(t) for t in translation[1:]) + except ValueError: + raise ValueError("Invalid header (line 3)") + + # Fourth line of the header should be "scale [float]" + line = f.readline().strip() + if not line.startswith(b"scale "): + raise ValueError("Invalid header (line 4)") + line = line.split(b" ") + if not len(line) == 2: + raise ValueError("Invalid header (line 4)") + scale = float(line[1]) + + # Fifth line of the header should be "data" + line = f.readline().strip() + if not line == b"data": + raise ValueError("Invalid header (line 5)") + + return size, translation, scale + + +def align_bbox(src, tgt): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/tools/preprocess_shapenet.py#L263 + + Return a copy of src points in the coordinate system of tgt by applying a + scale and shift along each coordinate axis to make the min / max values align. 
+ + Args: + src, tgt: Torch Tensor of shape (N, 3) + + Returns: + out: Torch Tensor of shape (N, 3) + """ + if src.ndim != 2 or tgt.ndim != 2: + raise ValueError("Both src and tgt need to have dimensions of 2.") + if src.shape[-1] != 3 or tgt.shape[-1] != 3: + raise ValueError( + "Both src and tgt need to have sizes of 3 along the second dimension." + ) + src_min = src.min(dim=0)[0] + src_max = src.max(dim=0)[0] + tgt_min = tgt.min(dim=0)[0] + tgt_max = tgt.max(dim=0)[0] + scale = (tgt_max - tgt_min) / (src_max - src_min) + shift = tgt_min - scale * src_min + out = scale * src + shift + return out + + +def voxelize(voxel_coords, P, V): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/tools/preprocess_shapenet.py#L284 + but changing flip y to flip x. + + Creating voxels of shape (D, D, D) from voxel_coords and projection matrix. + + Args: + voxel_coords: FloatTensor of shape (V, 3) giving voxel's coordinates aligned to + the vertices. + P: FloatTensor of shape (4, 4) giving the projection matrix. + V: Voxel size of the output. + + Returns: + voxels: Tensor of shape (D, D, D) giving the voxelized result. + """ + device = voxel_coords.device + voxel_coords = project_verts(voxel_coords, P) + + # Using the actual zmin and zmax of the model is bad because we need them + # to perform the inverse transform, which transform voxels back into world + # space for refinement or evaluation. Instead we use an empirical min and + # max over the dataset; that way it is consistent for all images. + zmin = SHAPENET_MIN_ZMIN + zmax = SHAPENET_MAX_ZMAX + + # Once we know zmin and zmax, we need to adjust the z coordinates so the + # range [zmin, zmax] instead runs from [-1, 1] + m = 2.0 / (zmax - zmin) + b = -2.0 * zmin / (zmax - zmin) - 1 + voxel_coords[:, 2].mul_(m).add_(b) + voxel_coords[:, 0].mul_(-1) # Flip x + + # Now voxels are in [-1, 1]^3; map to [0, V-1)^3 + voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0) + voxel_coords = voxel_coords.round().to(torch.int64) + valid = (0 <= voxel_coords) * (voxel_coords < V) + valid = valid[:, 0] * valid[:, 1] * valid[:, 2] + x, y, z = voxel_coords.unbind(dim=1) + x, y, z = x[valid], y[valid], z[valid] + voxels = torch.zeros(V, V, V, dtype=torch.uint8, device=device) + voxels[z, y, x] = 1 + + return voxels + + +def project_verts(verts, P, eps=1e-1): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L159 + + Project verticies using a 4x4 transformation matrix. + + Args: + verts: FloatTensor of shape (N, V, 3) giving a batch of vertex positions or of + shape (V, 3) giving a single set of vertex positions. + P: FloatTensor of shape (N, 4, 4) giving projection matrices or of shape (4, 4) + giving a single projection matrix. + + Returns: + verts_out: FloatTensor of shape (N, V, 3) giving vertex positions (x, y, z) + where verts_out[i] is the result of transforming verts[i] by P[i]. + """ + # Handle unbatched inputs + singleton = False + if verts.dim() == 2: + assert P.dim() == 2 + singleton = True + verts, P = verts[None], P[None] + + N, V = verts.shape[0], verts.shape[1] + dtype, device = verts.dtype, verts.device + + # Add an extra row of ones to the world-space coordinates of verts before + # multiplying by the projection matrix. We could avoid this allocation by + # instead multiplying by a 4x3 submatrix of the projectio matrix, then + # adding the remaining 4x1 vector. 
Not sure whether there will be much + # performance difference between the two. + ones = torch.ones(N, V, 1, dtype=dtype, device=device) + verts_hom = torch.cat([verts, ones], dim=2) + verts_cam_hom = torch.bmm(verts_hom, P.transpose(1, 2)) + + # Avoid division by zero by clamping the absolute value + w = verts_cam_hom[:, :, 3:] + w_sign = w.sign() + w_sign[w == 0] = 1 + w = w_sign * w.abs().clamp(min=eps) + + verts_proj = verts_cam_hom[:, :, :3] / w + + if singleton: + return verts_proj[0] + return verts_proj + + +class BlenderCamera(CamerasBase): + """ + Camera for rendering objects with calibration matrices from the R2N2 dataset + (which uses Blender for rendering the views for each model). + """ + + def __init__(self, R=r, T=t, K=k, device="cpu"): + """ + Args: + R: Rotation matrix of shape (N, 3, 3). + T: Translation matrix of shape (N, 3). + K: Intrinsic matrix of shape (N, 4, 4). + device: torch.device or str. + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + super().__init__(device=device, R=R, T=T, K=K) + + def get_projection_transform(self, **kwargs) -> Transform3d: + transform = Transform3d(device=self.device) + transform._matrix = self.K.transpose(1, 2).contiguous() # pyre-ignore[16] + return transform + + +def render_cubified_voxels( + voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs +): + """ + Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh. + + Args: + voxels: FloatTensor of shape (N, D, D, D) where N is the batch size and + D is the number of voxels along each dimension. + shader_type: shader_type: shader_type: Shader to use for rendering. Examples + include HardPhongShader (default), SoftPhongShader etc or any other type + of valid Shader class. + device: torch.device on which the tensors should be located. + **kwargs: Accepts any of the kwargs that the renderer supports. + Returns: + Batch of rendered images of shape (N, H, W, 3). + """ + cubified_voxels = cubify(voxels, CUBIFY_THRESH).to(device) + cubified_voxels.textures = TexturesVertex( + verts_features=torch.ones_like(cubified_voxels.verts_padded(), device=device) + ) + cameras = BlenderCamera(device=device) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=kwargs.get("raster_settings", RasterizationSettings()), + ), + shader=shader_type( + device=device, + cameras=cameras, + lights=kwargs.get("lights", PointLights()).to(device), + ), + ) + return renderer(cubified_voxels) diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py index 43243f5c..5d8dd2ae 100644 --- a/pytorch3d/datasets/utils.py +++ b/pytorch3d/datasets/utils.py @@ -1,8 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import math from typing import Dict, List -import torch from pytorch3d.structures import Meshes @@ -33,78 +31,4 @@ def collate_batched_meshes(batch: List[Dict]): collated_dict["mesh"] = Meshes( verts=collated_dict["verts"], faces=collated_dict["faces"] ) - - # If collate_batched_meshes receives R2N2 items with images and that - # all models have the same number of views V, stack the batches of - # views of each model into a new batch of shape (N, V, H, W, 3). - # Otherwise leave it as a list. - if "images" in collated_dict: - try: - collated_dict["images"] = torch.stack(collated_dict["images"]) - except RuntimeError: - print( - "Models don't have the same number of views. 
Now returning " - "lists of images instead of batches." - ) - - # If collate_batched_meshes receives R2N2 items with camera calibration - # matrices and that all models have the same number of views V, stack each - # type of matrices into a new batch of shape (N, V, ...). - # Otherwise leave them as lists. - if all(x in collated_dict for x in ["R", "T", "K"]): - try: - collated_dict["R"] = torch.stack(collated_dict["R"]) # (N, V, 3, 3) - collated_dict["T"] = torch.stack(collated_dict["T"]) # (N, V, 3) - collated_dict["K"] = torch.stack(collated_dict["K"]) # (N, V, 4, 4) - except RuntimeError: - print( - "Models don't have the same number of views. Now returning " - "lists of calibration matrices instead of batches." - ) - return collated_dict - - -def compute_extrinsic_matrix(azimuth, elevation, distance): - """ - Copied from meshrcnn codebase: - https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96 - - Compute 4x4 extrinsic matrix that converts from homogenous world coordinates - to homogenous camera coordinates. We assume that the camera is looking at the - origin. - Used in R2N2 Dataset when computing calibration matrices. - - Args: - azimuth: Rotation about the z-axis, in degrees. - elevation: Rotation above the xy-plane, in degrees. - distance: Distance from the origin. - - Returns: - FloatTensor of shape (4, 4). - """ - azimuth, elevation, distance = float(azimuth), float(elevation), float(distance) - - az_rad = -math.pi * azimuth / 180.0 - el_rad = -math.pi * elevation / 180.0 - sa = math.sin(az_rad) - ca = math.cos(az_rad) - se = math.sin(el_rad) - ce = math.cos(el_rad) - R_world2obj = torch.tensor( - [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]] - ) - R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) - R_world2cam = R_obj2cam.mm(R_world2obj) - cam_location = torch.tensor([[distance, 0, 0]]).t() - T_world2cam = -(R_obj2cam.mm(cam_location)) - RT = torch.cat([R_world2cam, T_world2cam], dim=1) - RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])]) - - # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it - # rotates the model 90 degrees about the x axis. 
To compensate for this quirk we
-    # roll that rotation into the extrinsic matrix here
-    rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
-    RT = RT.mm(rot.to(RT))
-
-    return RT
diff --git a/tests/data/test_r2n2_voxel_to_mesh_render.png b/tests/data/test_r2n2_voxel_to_mesh_render.png
new file mode 100644
index 0000000000000000000000000000000000000000..d458298ac6681aedb00559eee7da55290305bae6
Binary files /dev/null and b/tests/data/test_r2n2_voxel_to_mesh_render.png differ