From 8abbe22ffbc306b7be0e2e09ba1ce167430f2c7f Mon Sep 17 00:00:00 2001 From: David Novotny Date: Thu, 16 Apr 2020 13:59:34 -0700 Subject: [PATCH] ICP - point-to-point version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The iterative closest point algorithm - point-to-point version. Output of `bm_iterative_closest_point`: Argument key: `batch_size dim n_points_X n_points_Y use_pointclouds` ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- IterativeClosestPoint_1_3_100_100_False 107569 111323 5 IterativeClosestPoint_1_3_100_1000_False 118972 122306 5 IterativeClosestPoint_1_3_1000_100_False 108576 110978 5 IterativeClosestPoint_1_3_1000_1000_False 331836 333515 2 IterativeClosestPoint_1_20_100_100_False 134387 137842 4 IterativeClosestPoint_1_20_100_1000_False 149218 153405 4 IterativeClosestPoint_1_20_1000_100_False 414248 416595 2 IterativeClosestPoint_1_20_1000_1000_False 374318 374662 2 IterativeClosestPoint_10_3_100_100_False 539852 539852 1 IterativeClosestPoint_10_3_100_1000_False 752784 752784 1 IterativeClosestPoint_10_3_1000_100_False 1070700 1070700 1 IterativeClosestPoint_10_3_1000_1000_False 1164020 1164020 1 IterativeClosestPoint_10_20_100_100_False 374548 377337 2 IterativeClosestPoint_10_20_100_1000_False 472764 476685 2 IterativeClosestPoint_10_20_1000_100_False 1457175 1457175 1 IterativeClosestPoint_10_20_1000_1000_False 2195820 2195820 1 IterativeClosestPoint_1_3_100_100_True 110084 115824 5 IterativeClosestPoint_1_3_100_1000_True 142728 147696 4 IterativeClosestPoint_1_3_1000_100_True 212966 213966 3 IterativeClosestPoint_1_3_1000_1000_True 369130 375114 2 IterativeClosestPoint_10_3_100_100_True 354615 355179 2 IterativeClosestPoint_10_3_100_1000_True 451815 452704 2 IterativeClosestPoint_10_3_1000_100_True 511833 511833 1 IterativeClosestPoint_10_3_1000_1000_True 798453 798453 1 -------------------------------------------------------------------------------- ``` Reviewed By: shapovalov, gkioxari Differential Revision: D19909952 fbshipit-source-id: f77fadc88fb7c53999909d594114b182ee2a3def --- pytorch3d/ops/__init__.py | 3 +- pytorch3d/ops/points_alignment.py | 287 ++++++++++++++++++++++++++---- pytorch3d/ops/utils.py | 58 +++++- tests/bm_points_alignment.py | 35 +++- tests/icp_data.pth | Bin 0 -> 81765 bytes tests/test_points_alignment.py | 265 ++++++++++++++++++++++++++- 6 files changed, 603 insertions(+), 45 deletions(-) create mode 100644 tests/icp_data.pth diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 48703e2a..fe522d3d 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -6,9 +6,10 @@ from .graph_conv import GraphConv from .knn import knn_gather, knn_points from .mesh_face_areas_normals import mesh_face_areas_normals from .packed_to_padded import packed_to_padded, padded_to_packed -from .points_alignment import corresponding_points_alignment +from .points_alignment import corresponding_points_alignment, iterative_closest_point from .sample_points_from_meshes import sample_points_from_meshes from .subdivide_meshes import SubdivideMeshes +from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean from .vert_align import vert_align diff --git a/pytorch3d/ops/points_alignment.py b/pytorch3d/ops/points_alignment.py index 80100f5b..7ac3f182 100644 --- a/pytorch3d/ops/points_alignment.py +++ b/pytorch3d/ops/points_alignment.py @@ -1,22 +1,231 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import warnings -from typing import List, Tuple, Union +from typing import TYPE_CHECKING, List, NamedTuple, Optional, Union import torch -from pytorch3d.ops import utils as oputil +from pytorch3d.ops import knn_points from pytorch3d.structures import utils as strutil -from pytorch3d.structures.pointclouds import Pointclouds + +from . import utils as oputil + + +if TYPE_CHECKING: + from pytorch3d.structures.pointclouds import Pointclouds + + +# named tuples for inputs/outputs +class SimilarityTransform(NamedTuple): + R: torch.Tensor + T: torch.Tensor + s: torch.Tensor + + +class ICPSolution(NamedTuple): + converged: bool + rmse: Union[torch.Tensor, None] + Xt: torch.Tensor + RTs: SimilarityTransform + t_history: List[SimilarityTransform] + + +def iterative_closest_point( + X: Union[torch.Tensor, "Pointclouds"], + Y: Union[torch.Tensor, "Pointclouds"], + init_transform: Optional[SimilarityTransform] = None, + max_iterations: int = 100, + relative_rmse_thr: float = 1e-6, + estimate_scale: bool = False, + allow_reflection: bool = False, + verbose: bool = False, +) -> ICPSolution: + """ + Executes the iterative closest point (ICP) algorithm [1, 2] in order to find + a similarity transformation (rotation `R`, translation `T`, and + optionally scale `s`) between two given differently-sized sets of + `d`-dimensional points `X` and `Y`, such that: + + `s[i] X[i] R[i] + T[i] = Y[NN[i]]`, + + for all batch indices `i` in the least squares sense. Here, Y[NN[i]] stands + for the indices of nearest neighbors from `Y` to each point in `X`. + Note, however, that the solution is only a local optimum. + + Args: + **X**: Batch of `d`-dimensional points + of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object. + **Y**: Batch of `d`-dimensional points + of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object. + **init_transform**: A named-tuple `SimilarityTransform` of tensors + `R`, `T, `s`, where `R` is a batch of orthonormal matrices of + shape `(minibatch, d, d)`, `T` is a batch of translations + of shape `(minibatch, d)` and `s` is a batch of scaling factors + of shape `(minibatch,)`. + **max_iterations**: The maximum number of ICP iterations. + **relative_rmse_thr**: A threshold on the relative root mean squared error + used to terminate the algorithm. + **estimate_scale**: If `True`, also estimates a scaling component `s` + of the transformation. Otherwise assumes the identity + scale and returns a tensor of ones. + **allow_reflection**: If `True`, allows the algorithm to return `R` + which is orthonormal but has determinant==-1. + **verbose**: If `True`, prints status messages during each ICP iteration. + + Returns: + A named tuple `ICPSolution` with the following fields: + **converged**: A boolean flag denoting whether the algorithm converged + successfully (=`True`) or not (=`False`). + **rmse**: Attained root mean squared error after termination of ICP. + **Xt**: The point cloud `X` transformed with the final transformation + (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an + instance of `Pointclouds`, otherwise returns `torch.Tensor`. + **RTs**: A named tuple `SimilarityTransform` containing + a batch of similarity transforms with fields: + **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`. + **T**: Batch of translations of shape `(minibatch, d)`. + **s**: batch of scaling factors of shape `(minibatch, )`. + **t_history**: A list of named tuples `SimilarityTransform` + the transformation parameters after each ICP iteration. + + References: + [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992. + [2] https://en.wikipedia.org/wiki/Iterative_closest_point + """ + + # make sure we convert input Pointclouds structures to + # padded tensors of shape (N, P, 3) + Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X) + Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y) + + b, size_X, dim = Xt.shape + + if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]): + raise ValueError( + "Point sets X and Y have to have the same " + + "number of batches and data dimensions." + ) + + if ((num_points_Y < Yt.shape[1]).any() or (num_points_X < Xt.shape[1]).any()) and ( + num_points_Y != num_points_X + ).any(): + # we have a heterogeneous input (e.g. because X/Y is + # an instance of Pointclouds) + mask_X = ( + torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None] + < num_points_X[:, None] + ).type_as(Xt) + else: + mask_X = Xt.new_ones(b, size_X) + + # clone the initial point cloud + Xt_init = Xt.clone() + + if init_transform is not None: + # parse the initial transform from the input and apply to Xt + try: + R, T, s = init_transform + assert ( + R.shape == torch.Size((b, dim, dim)) + and T.shape == torch.Size((b, dim)) + and s.shape == torch.Size((b,)) + ) + except Exception: + raise ValueError( + "The initial transformation init_transform has to be " + "a named tuple SimilarityTransform with elements (R, T, s). " + "R are dim x dim orthonormal matrices of shape " + "(minibatch, dim, dim), T is a batch of dim-dimensional " + "translations of shape (minibatch, dim) and s is a batch " + "of scalars of shape (minibatch,)." + ) + # apply the init transform to the input point cloud + Xt = _apply_similarity_transform(Xt, R, T, s) + else: + # initialize the transformation with identity + R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype) + T = Xt.new_zeros((b, dim)) + s = Xt.new_ones(b) + + prev_rmse = None + rmse = None + iteration = -1 + converged = False + + # initialize the transformation history + t_history = [] + + # the main loop over ICP iterations + for iteration in range(max_iterations): + Xt_nn_points = knn_points( + Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True + )[2][:, :, 0, :] + + # get the alignment of the nearest neighbors from Yt with Xt_init + R, T, s = corresponding_points_alignment( + Xt_init, + Xt_nn_points, + weights=mask_X, + estimate_scale=estimate_scale, + allow_reflection=allow_reflection, + ) + + # apply the estimated similarity transform to Xt_init + Xt = _apply_similarity_transform(Xt_init, R, T, s) + + # add the current transformation to the history + t_history.append(SimilarityTransform(R, T, s)) + + # compute the root mean squared error + Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2) + rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0] + + # compute the relative rmse + if prev_rmse is None: + relative_rmse = rmse.new_ones(b) + else: + relative_rmse = (prev_rmse - rmse) / prev_rmse + + if verbose: + rmse_msg = ( + f"ICP iteration {iteration}: mean/max rmse = " + + f"{rmse.mean():1.2e}/{rmse.max():1.2e} " + + f"; mean relative rmse = {relative_rmse.mean():1.2e}" + ) + print(rmse_msg) + + # check for convergence + if (relative_rmse <= relative_rmse_thr).all(): + converged = True + break + + # update the previous rmse + prev_rmse = rmse + + if verbose: + if converged: + print(f"ICP has converged in {iteration + 1} iterations.") + else: + print(f"ICP has not converged in {max_iterations} iterations.") + + if oputil.is_pointclouds(X): + Xt = X.update_padded(Xt) # type: ignore + + return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history) + + +# threshold for checking that point crosscorelation +# is full rank in corresponding_points_alignment +AMBIGUOUS_ROT_SINGULAR_THR = 1e-15 def corresponding_points_alignment( - X: Union[torch.Tensor, Pointclouds], - Y: Union[torch.Tensor, Pointclouds], + X: Union[torch.Tensor, "Pointclouds"], + Y: Union[torch.Tensor, "Pointclouds"], weights: Union[torch.Tensor, List[torch.Tensor], None] = None, estimate_scale: bool = False, allow_reflection: bool = False, - eps: float = 1e-8, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + eps: float = 1e-9, +) -> SimilarityTransform: """ Finds a similarity transformation (rotation `R`, translation `T` and optionally scale `s`) between two given sets of corresponding @@ -29,25 +238,25 @@ def corresponding_points_alignment( The algorithm is also known as Umeyama [1]. Args: - X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)` + **X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)` or a `Pointclouds` object. - Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)` + **Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)` or a `Pointclouds` object. - weights: Batch of non-negative weights of + **weights**: Batch of non-negative weights of shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional tensors that may have different shapes; in that case, the length of i-th tensor should be equal to the number of points in X_i and Y_i. Passing `None` means uniform weights. - estimate_scale: If `True`, also estimates a scaling component `s` + **estimate_scale**: If `True`, also estimates a scaling component `s` of the transformation. Otherwise assumes an identity scale and returns a tensor of ones. - allow_reflection: If `True`, allows the algorithm to return `R` + **allow_reflection**: If `True`, allows the algorithm to return `R` which is orthonormal but has determinant==-1. - eps: A scalar for clamping to avoid dividing by zero. Active for the + **eps**: A scalar for clamping to avoid dividing by zero. Active for the code that estimates the output scale `s`. Returns: - 3-element tuple containing + 3-element named tuple `SimilarityTransform` containing - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`. - **T**: Batch of translations of shape `(minibatch, d)`. - **s**: batch of scaling factors of shape `(minibatch, )`. @@ -58,8 +267,8 @@ def corresponding_points_alignment( """ # make sure we convert input Pointclouds structures to tensors - Xt, num_points = _convert_point_cloud_to_tensor(X) - Yt, num_points_Y = _convert_point_cloud_to_tensor(Y) + Xt, num_points = oputil.convert_pointclouds_to_tensor(X) + Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y) if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any(): raise ValueError( @@ -90,8 +299,8 @@ def corresponding_points_alignment( weights = mask if weights is None else mask * weights.type_as(Xt) # compute the centroids of the point sets - Xmu = oputil.wmean(Xt, weights, eps=eps) - Ymu = oputil.wmean(Yt, weights, eps=eps) + Xmu = oputil.wmean(Xt, weight=weights, eps=eps) + Ymu = oputil.wmean(Yt, weight=weights, eps=eps) # mean-center the point sets Xc = Xt - Xmu @@ -107,7 +316,7 @@ def corresponding_points_alignment( if (num_points < (dim + 1)).any(): warnings.warn( "The size of one of the point clouds is <= dim+1. " - + "corresponding_points_alignment can't return a unique solution." + + "corresponding_points_alignment cannot return a unique rotation." ) # compute the covariance XYcov between the point sets Xc, Yc @@ -117,6 +326,16 @@ def corresponding_points_alignment( # decompose the covariance matrix XYcov U, S, V = torch.svd(XYcov) + # catch ambiguous rotation by checking the magnitude of singular values + if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not ( + num_points < (dim + 1) + ).any(): + warnings.warn( + "Excessively low rank of " + + "cross-correlation between aligned point clouds. " + + "corresponding_points_alignment cannot return a unique rotation." + ) + # identity matrix used for fixing reflections E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1) @@ -148,26 +367,18 @@ def corresponding_points_alignment( # unit scaling since we do not estimate scale s = T.new_ones(b) - return R, T, s + return SimilarityTransform(R, T, s) -def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]): +def _apply_similarity_transform( + X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor +) -> torch.Tensor: """ - If `type(pcl)==Pointclouds`, converts a `pcl` object to a - padded representation and returns it together with the number of points - per batch. Otherwise, returns the input itself with the number of points - set to the size of the second dimension of `pcl`. + Applies a similarity transformation parametrized with a batch of orthonormal + matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T` + of shape `(minibatch, d)` and a batch of scaling factors `s` + of shape `(minibatch,)` to a given `d`-dimensional cloud `X` + of shape `(minibatch, num_points, d)` """ - if isinstance(pcl, Pointclouds): - X = pcl.points_padded() - num_points = pcl.num_points_per_cloud() - elif torch.is_tensor(pcl): - X = pcl - num_points = X.shape[1] * torch.ones( - X.shape[0], device=X.device, dtype=torch.int64 - ) - else: - raise ValueError( - "The inputs X, Y should be either Pointclouds objects or tensors." - ) - return X, num_points + X = s[:, None, None] * torch.bmm(X, R) + T[:, None, :] + return X diff --git a/pytorch3d/ops/utils.py b/pytorch3d/ops/utils.py index fa690ee1..134172b0 100644 --- a/pytorch3d/ops/utils.py +++ b/pytorch3d/ops/utils.py @@ -1,9 +1,13 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union import torch +if TYPE_CHECKING: + from pytorch3d.structures import Pointclouds + + def wmean( x: torch.Tensor, weight: Optional[torch.Tensor] = None, @@ -41,3 +45,55 @@ def wmean( return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp( eps ) + + +def eyes( + dim: int, + N: int, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Generates a batch of `N` identity matrices of shape `(N, dim, dim)`. + + Args: + **dim**: The dimensionality of the identity matrices. + **N**: The number of identity matrices. + **device**: The device to be used for allocating the matrices. + **dtype**: The datatype of the matrices. + + Returns: + **identities**: A batch of identity matrices of shape `(N, dim, dim)`. + """ + identities = torch.eye(dim, device=device, dtype=dtype) + return identities[None].repeat(N, 1, 1) + + +def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]): + """ + If `type(pcl)==Pointclouds`, converts a `pcl` object to a + padded representation and returns it together with the number of points + per batch. Otherwise, returns the input itself with the number of points + set to the size of the second dimension of `pcl`. + """ + if is_pointclouds(pcl): + X = pcl.points_padded() # type: ignore + num_points = pcl.num_points_per_cloud() # type: ignore + elif torch.is_tensor(pcl): + X = pcl + num_points = X.shape[1] * torch.ones( # type: ignore + X.shape[0], device=X.device, dtype=torch.int64 + ) + else: + raise ValueError( + "The inputs X, Y should be either Pointclouds objects or tensors." + ) + return X, num_points + + +def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]): + """ Checks whether the input `pcl` is an instance `Pointclouds` of + by checking the existence of `points_padded` and `num_points_per_cloud` + functions. + """ + return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud") diff --git a/tests/bm_points_alignment.py b/tests/bm_points_alignment.py index 24f8d0d2..942e76aa 100644 --- a/tests/bm_points_alignment.py +++ b/tests/bm_points_alignment.py @@ -5,7 +5,38 @@ from copy import deepcopy from itertools import product from fvcore.common.benchmark import benchmark -from test_points_alignment import TestCorrespondingPointsAlignment +from test_points_alignment import TestCorrespondingPointsAlignment, TestICP + + +def bm_iterative_closest_point() -> None: + + case_grid = { + "batch_size": [1, 10], + "dim": [3, 20], + "n_points_X": [100, 1000], + "n_points_Y": [100, 1000], + "use_pointclouds": [False], + } + + test_args = sorted(case_grid.keys()) + test_cases = product(*case_grid.values()) + kwargs_list = [dict(zip(test_args, case)) for case in test_cases] + + # add the use_pointclouds=True test cases whenever we have dim==3 + kwargs_to_add = [] + for entry in kwargs_list: + if entry["dim"] == 3: + entry_add = deepcopy(entry) + entry_add["use_pointclouds"] = True + kwargs_to_add.append(entry_add) + kwargs_list.extend(kwargs_to_add) + + benchmark( + TestICP.iterative_closest_point, + "IterativeClosestPoint", + kwargs_list, + warmup_iters=1, + ) def bm_corresponding_points_alignment() -> None: @@ -21,7 +52,7 @@ def bm_corresponding_points_alignment() -> None: } test_args = sorted(case_grid.keys()) - test_cases = product(*[case_grid[k] for k in test_args]) + test_cases = product(*case_grid.values()) kwargs_list = [dict(zip(test_args, case)) for case in test_cases] # add the use_pointclouds=True test cases whenever we have dim==3 diff --git a/tests/icp_data.pth b/tests/icp_data.pth new file mode 100644 index 0000000000000000000000000000000000000000..e99a61b3704b93dd1df117bdda89851e8db8a87b GIT binary patch literal 81765 zcmZ^p1zc3k*T+dIML{gE5CjAjSX#JqRs;)BT)RNP7J++N?7*Nw!2m_YLdBL4C3en= zB6c7qc4K0N0>NdouJ*1@&Th`UW?f4{BBEGJ<_wbZ zT0?wPd^FfctH<%$J{^73eN^LkorDl}hikXB)p!K zZz|!NO=azvsy?;B=eXI~+lzSpsY-tHfACxU)eVQN{E9F-U&MR1&aJe0w4R$^SF zysLzF`wxtrgR3Loi(wQocCQbP0^{*F#+(vEq&!G?apM?!J2w|+zBj`tV(eLuiRJq! zS`y29DG}aMzORJu_a6wBB|E-9LnvZApbink4^+h1NBMSNDeouYS><3BW04}p{*A}j z-qq2?!I>YzSc)&WMA@9WFZj^EF&?I54wv#HB>c!mn`H0i>|*E0kDA)VZw_Fthlt_B6+w)kitG9uv!_D5JbWiC8J+Qzd*_BM3)VHq^SYM6Ifi^6EN7 z48KMZ<+aMUr%U+^3BOLtuV;fr-JqY@SdUML*xDY))N2-}JYiXDOMR zrTi8NzqRouiR{_H!*634MUU97#MC7#`){&#C^0*w{4NQf^B)+IJ?rPY8AdT&?5Piq zBF20F#(19+vtP;|knjf^$2f=-{qqpRD29vNdQ2>zr;KsF5>X)K4@>yM|3KKevKuHS=FV+V~ z5#vjLV|-bOsg&|pBz#pP7)LhZI5@F>b(LWhF}_xhiRG^=V|+u2xGCjtN%-o<5UwIS zksF)GZ!?4<#&_xvG5lRcjPEJmeqYKzknj(s{390Ox*_@Tf5h04O~j7;lltI1{YP+~ z{SD4@C9_7#zmV`R8*P%ko2xVHxc@W|qZYF#|oybP?t-g z3sXw28Bv9!(wC(~e-vE~xJlf*sS!_A9!DV6WgPA65hr)5_0tU`i@D zGAj|Bm~HePMD|Xsmk7>^`nrijaA8u(R?HoOD=XGdAA%dxN_H=%q=GxM62XJnM%YgF z4s3j4@!=Hp%A67=>*KCq11tW`shDXcyEju(!IN2u(1+PZ*ouXY$l29Z@KV$(bL!1x zJv&zD%L>Yz_G3!P?a!1{7{IJV7|85@a2@R&M1qf^UYS#0ChNE{f*+HLoDO2IO6Fju zq=G-Q5@AT4mN3h5s+f#~p-eZPQ<0049UBjZu?zm@bU4$>_SfZfgo0C_(~(RoKW`LM zQeiZ+5@8IpjoyQui=z|kOJf!Fb@QDtj!7lEZWI{LihoCe0H&4fK&GU^1ZE{d5VQZq zR($7!74^!ThA>$lcSTMovf|&IPGVZgp3Ia~n8K_?n96LUYcq6`cBuc?+46 z3X7PP2#cBhk9$x&`U=sCdgWS0Sif;kizVLuw7l8M8XP1y)vmQnXG5W3aPB1 zOllfaO71GAq{3=uCBhnJ8|8|e*#o$puvSs8Olmrlb=(*sgGohF*RfY6b3Id1A(L5& zu%S+GWR|70Zbse2bmK`C+1s&)b0Lde@Hg(8nf^!IwIwE zbw88!kym8(04x5@>OrQJ>_bdRgXlh7V6vVaD;#D8WmXHB zQgV+lB^8b`D-nvAZJcZG>gwn!6f5eLSuJ6*jvFJCGO5U_jJ+zElqsoz%t{1#oj%4a z%W2(|TE_H$WL0G6!a{SLUGO*XCz$?6;LBO@e}O;A3d#?xU`i^SVpbxYX10-=aAL!l ziwo;cXB73yIrS`)N_gFzdX5$U&Z*~_R&!OBW{UxPHY*7?6!pra-ej_#9V^^o1!YpJ znNo6ZGbI)7Fe?%6GW#D~k)yMta8FUMOzM3m>$ow(111$oeaK#w%tuT~g~!ZFgeP_S zDYGo8b#v-7rW;SHy_-lebw6hp{7q^N)5`YO^`;lB_`jsSWCi61zG6x$yk=G+ykWMH zn{cvsa%Bqvk?>YgubfnCnN-5-Ce?SW_;*r$&$N>Lfhno*ky(lGiP=WjYzM&Ej_pDT zpB44Wq<&$tKJbd9er3hKN&UvOlKq`2sqlkYiSU!z#@K8@z@{ZT;g_ObnbhA**0W=U zKdhiks*2iQt}0VfL5*37pw4U~T=r1yNQFZPjK< zD(EmP5p?Uc9 zLuMsHOJ*Cn3B}fn8(Z87MvD6SS+y0@N_yRt-I^8uj_z%kR>IpdB^BB+D-n#DZ4_?r z=D=3@f{CJDnbr18)<<5E)efxqH>(|)R$ow34U>wjc4x0jW)G&M zLQiHTf^D6)W0qyLZdSEty78>C4SNS?J3+)Q_?uM+rj_ll%c>(Q{x7Retf2frXQrfr z3$qfzmD&He31`LQjGLmqep>Ct^grnCtoT2=2P-J)98*$(Fe?#&*~aN2XM08$E9#YL z?agF;=oM-8WW~Q}?ZdQ^?ZuQ-@Mcyb^kudYwv&jh^Au_Ar>Iw^wLg>f>{wv{D=5=C zkSQhChbgJx%dAB3W3~~lVzs0=N)QGq>Xm67%w!!mM(}4+k=7yXRmqevB^8D;D-nj( z>EX;O&alQS&dvUD;Pm(CMfG1!<(R{({mWUuTc!Rlr*gxh!68=tS&ZgCuVkxW5!-&? zn>27Lmd3nR1E*s7!)rHmvULlu^Dn1z0nh9H+gbM*&+9dC*6lnuZRk`S&@^k{R4hz+ z{RU3Os*-Qs(5cuWZPCD4cYtTmz^Pck@P-YXijk3T*}z$Mpl0-Mr(4}&U8{yp#W_Li zhECRV_%;okig}f9+tA6rseHQzPQ{4F8#iz&R&~5d183d2OZx`Sx4V`Rbi0|0I zS+_^nsiBjtH2BU9oOQc^rvGv(A9VOG4V;Rn6yB_%Q?Wm4-oU9CU-+&KoQjc?xA?cS zZuP)hHgqaZ(*D(tT-oNm;%uy2gDVxAGrV;}Cwr{tZ5lY0M~)4hibIMX4V;Q6Y`*8e zoh-xbgKQf**+VjK*T7kK%3$Bnsptrz2F|*J1cwIBy0Z?)hE6u3@=gt$bzAMu4V;Sk zns@oP)1mI1z_o!>u|VP78afp{u2%zR-O+&izn#jxY>x)cxr#xQVz*#r2^lsp+J23NX;8e_+e4hqR#VVNh`nR)gbBXuR~8RC7ARY_tq(nFtWIB) z_n>)eJm{*WHstU7-j6*&J5C!yuO^$3b(e?Wur`-v^Y4a|_P56nnn=8jit+JUa@{;b@J1^R!Z%}VUi?NVnSA-{X=R%SCLD*kl2)1k7ApP}I*b`6+ zQmrBQ_(KbfI8#c>TZCcrtT?)IeoOJdZ6ConY{Mvae=s8bwzt(Iiqv z@VxpS9HrkAtY?>F?#I2PtI7;=EPE$ph_hhU3{xmBAB+Jc8|QrgNp6L0r}Dg}blf#( zSRA^G+mqB4gI0Ormhyc1BDoD@jq;Es-#bhO%fFPie4GgbpF~33;&_sG;VwxQM}U44 zPcqm!3?9Dk#r=s^5$}1PM}E&*iDzzg!?+RKX}}0Ek7ztbR`~&E%%dg zMeRt;hkXQmrbBqb0Xl5=MdDpJm+X#rf=Ba@ki^;*aJQ!=PBYaN=a#>O4G|piiX4m! zSH%;*4{gD1_9Cbnn@+z!tHiLwt)b=CwX)P5-ng)I8*crThexVU&e{m@|soIi2y7v|jUh7ESZ)s2ad{|Ga9>C3qdm$-zJ{;oC(*qlp!W(e{=FDgfhu^jpPrzddNvPx z*}f>kt@O;hLQvdk2WJt;jdh_-ZQ2ew@`4$T!a3q7&NspX{jsSlog*Amil3a}*PhRkzT9{F%!FisBYBtEizE%lo!#mJEd zu-%XX`YNO+j4^yc+U-yy5k@B1t@S2cI4cDrlXh`MD|&$ChfTQoO#o?89f2clreW%s z7m%M;z)h!$4k)gXMEG@by(+xHz@D{N$;35G*|kwqLw)Ti+(Kh;_fHuw)=wCN)K! zsL^0`{UP?|*TK{_muQgd0V0n6M6G9?hew_gXx*_ItG*m1Q9F|$(bJZ`tqvm1pZ+F~ zr#8dr)KTu+`ktek_EZqlul{&5cs=>Hb0-=sx$%KH?$ixGi(6`$1h{6XfyhFSS&rJRe@6tpFw!}dAcny7KIndAP-FC zDqr>nm)D+H!4=WtZ+Ftv$o&x2_Z(+ed?EPadu8kgz^JDko-u++caF-GIu?^vBV+hPIABzDFBJ$owpL}bg1L4(YAU$|F z*QbvT(emoV&5qeZvLzpItldPgel(Rj=Eut_-?xW9uaVnq^$NEpjS;VXG#{7m&j!!W zv!Lma4lu~E9~rY}4~{jDBi?C?!ps5ogZByEo#w^8@_x@DzY~GxTA} z{&z6<*$VFJ)y)t!dm|cKdBUC(hVXs7Cz&&_2+c1Q(3PI)80gxIl_;~y&aIItU zn&=S?AD2jAW-oBg_+6T&b&q=;^aHXV?!yZqRvg_i63*u55s~R;n9*ws_byBV5ieIl z^7l&+?z$e$L?WD}eH4%ArjsKJcj1HMF&-Ku5U=DTh4qv3c=3B5$^(6=lZUe0bETigIa=_MoG||rz z)1zlc;Dw}BEyUM7&E%nfw^Ptr0&Ko*7By^7IPI}->F-uI$nuOdSZk0(LKpeq-d{HGF{ZhAMY}O@ zuX!iXYZ~NHqA?5$?QEf{>1;d~wgQ(7rLgPEE@=HS1uXld5tFe|IQEe-wJic{5q6YT z9p8hqyMKq0uV?X%s*d>B=qaH3EVI;tFXf`&UWK!@JQ=B*h|S+{)GlxWI_krIHP26sM7&8MOkQBQQbjb5Kaxu3( zKD=8Fb7jdg)9fxVFWd(wZ+%YNB;KayyC;zq;nAdZhLw2#@qE~&Q%t6vjfLM20=Q%*F zk0~L3k4B-_j3;zc##{2fNqczZ9*_DLOkw$ockVjVj&YSiw`hUQ39=@}TwXI!2S=tK zB>PjMY08el+<;Tbq+e|T)MPIsQ74CRPRCo5Yn%1xr{5ECs$T$okz&TRy!VRsJP-p{ z=iZW?GW-JHca6rrzg5K{lGhLxR|Oq*_aI%Ht)g)?RdhwybO_%O?J;mV(yD+OdTLb+ z9^bGQt_{7z>6|jg&UCPVD2 zO=!LR9=*Q&8~1edA2M~xMUp*cE2n-j3T?IbqG^u0c%!KSn9hg<dmUfiacrOKYgDo;*+y~k&e1=m3RAKI?LfY}sbSixx zLl*7x!=)#hLspa++0y_9w%mZXKDxuJkX*Qv*Mz{4IIJ|d1MWAo=#EGYtU28Sub(`} zjj>P_XLd=!%_@f=e3c1Mzwuz(B^Q&9I&wd1?eL9*741(>(ZT86>AWt%vZkY7(cb#m zq>brzQex_ZDFNo7e$WGt{k)1#d|bt2pBIvu(S|7I+F>XEy`;9V3GOYL#_h3rUiv4t zwK)6aL++Bv8I0V~O`bouJ$7%JMAh7^AlQ8p+5`{9dp(k%mx(6~=-iFYDUU(}-3@5J z-3gygK1SPIe#7N$G=npI5KcQOBW~YR#9yc0lMP$G2fZzaL6gM8)cBzp(T;wKww+>0 zV6K3vlf7W+k2B=P&b8e5w=1Yy%OOzysvGWa`jIXXo{$-n))4=muIQ0Cl3HCcK?|Ej z5IcDS+LYJ=NYYSUKZjnry&FGE5^?u0IewEmqg1;hw(n-y^|u9N(98;$ z^YXsz`l}phHZ%afE!idO7axKpMYkb*#dmOMWeCG^gUQ0HElJjgHFWHtk#yDCS9n2X z6m~apfwq0_lD568Y12FEP;_b*7{2^U%%-J5>*?=F#@@X+vT_o=p(`V|gIklM=a)l^ z%Mnz2$1?8m>t)y}t_$1`zDHiWAoQLYN>40rk7Y+S#m|gN;i2vtXr17Tk^P^OLXB(K zHYkzSKQ$kDk5h!^cv8-30PJhI6bC%HO z{&@?zR3*l+@C0I(FcNC}wSo(;)v(Q!$#BHd3d;xU%2sOMMemg(F)}y>%eHTY8(v*8 zzxpAWqurkKD%XSa?WfSJTOu&ev_J!IPdqcF0@rbR;(KS~JhmRn!ti)+^7HLxVjDIS z(w;1%__aMga=eTqcWfp5*cvQ3(E@+P$J28LgrcDlu-Mxf#nD#uVQy!f(Mv4* zb2o!K|KJA+yt5dO5N#|f+Cpdcd|4XhYXN2s_iEafavRO@?!C{k5Rim$_VE$xh5W zt0kY~WkrvR(`D`MjKjq{*PzR1eGL2XhMv6WgBuNesdW_VucU^0?bAWyl_6;E9|6Q- zBbs>`lc1CN;FA^u%~GA{19=44j%WtcTerq-ZJT3Mr$^AXX%y|N)j~Wzqd#|kln93{ z%Y$dfmSe=-&RG4cJ9OMuOMXkg(bldh;PpL)Jm9)O+S`wq?ROq~MTtS^;(@&j=U{mA z@nAFdBo?Umf|PkPxkKG|!j|^dU}C)t)QTUGbD>+&rjr_EOkGVfd#{9tHJ5SC+0~@& z&5gJsW+((c*iV}VI1s-!+aUPPJTjrmgnr!d1lfQkUe^69sINH&d%A_-yV>h8J3QWF zW4sI(m+Q!PO%JEsv}}AB;RlDu6p=gJVw9*hm3MDB5{$zaLmp~iru7`?wzPxX?fNii z=J<(x^RR&~lZK(@lBrNM+M8y2+#@|~*c>YNBwb#Wga4)oynVnBM}ZAg*)JjwZr#DV zb5_b8`R>MI{#L29?5ylhp*aRT&cMj#yZal=ZJNQ&v1PQ2>QcC}U!B~tK2CiuS0H*%hfli)z=0zJNachZ zaJq#S9WkL7wXxVq-P&J)EorX!dcGQF8;zqc$8DtVwuj)R^!?y*Hx~;+$B>!#Y-vg1 z4sz$vR#|Y`VH#Spl#Hx62a$abfN0PINa{JBWOX`C&5T+>_+lF2%Q=rb66mn^4%C1bVqi8--f08T3ok*=wIX_U zx;I>2U`;;xnTd1b%1E5N8GPE8hL%GYldjA4u+XN&qx{GgIR0}f9Pyh+C(ZFh&u!1> zx|##nB9+bGdbe@Hpm^-F%olpIHSLKTo5+mero?$c572DC3!5LJ(7m`V>^&%khcUK*zvtN|3dD0XNq#p35M-Py zmi7Oz1XPl4%G$3j=ALv2l3DNVAufnBhO1L9!t7-&p+(Fp@GG5zvk!%NtlSxkHQ!_4 zaE*>UbXz2J8|LLP{Q3(zAo&}OjJ3oA$r&(z?R5yAy&Ffd(_Y12Grx_pE3ig$RA{Z&%G#)dfWsvw-gYI30ats zuS+cVsKD=LWz>DxWq2MWBfD;vgG$>zF!SnRw0r!R20z$C7nUW{JAVCf_r~6^$VNhU zx#)6*sT-h)-bGwASf6-JIEqtx#=(Wt&8Tz#&+gCJ{;i+7iFjvcdC3{IpH%jH4`KZH%q9me zUB$^Qw4hkMkefbe6N=lfCkrlzLwPKv1^U^L^KG+Y3MZnztBzN?xM+E;a0ZqAd+NCx>q2H0@I{vD9C6ACs?16R^3~G#ESl26gSSpE$kIe%bT_ojGeRuK(tatJO@PpGh*r%sW70e;NYL0U|Q*pH`;omcC&+2JEIls zSye%%`S*i8dhNu1tJSfrayZVoxEr>7HiKfZD~231gHJv7!;}7x!Rpu<+;rIzTj=G` zSN&aJkb4l8XNt(=b@s6KP#i3f^o3FL7Er;eBii>Y#=9d8VN3ONoMxYk&&^Unv!yp) zn7$phK3xS4o{2DVj4k#w)yDn_a;O||7VWIlILBJHPU>|NZXY^oz50+CqzLL{PKCUN&g;l~SuAYv{dW zBe)Qo&s6Q>Gn6GC;hI=AFR2Z*1w)N&tZZ@twvBHKBUaYZpEsg0Hefw*4wT}#P9`v@ zd^l`qI}==$n+TVECf_gqTJ&29VX-Z%}?Sz`bUozW4WXllb+!GH@%4I`glOoqe4aoX!# zI&3w!gbk(7=x-}6@m1T03aql50VLR@arHgmQ z6j1B#VIX5;=;aj>>@5uC7KfXG{_?eS<*Zw@eUA!At7=BNT+oAc9lFxtx89&`(=>cp z^^TlhP)wgK-4ET{XrbM~t8`HIHnQqLDF|1hQ19w!>izJ@C5vPeSWsJX7%t=p_^wzA+%hT2Nl3;|Cy(Nc6MoPq&JX4Yiy&d{DhQqL z4(t1EMceAW5ZPZx+~=h){dH;s73&AEwZL>N&uv9UbvOv7qtf7F@82|TzYLw?qH$8~ z8ZcvPNfWpExN0JY?pyEBxOOYB$A>_wPn^I`qJh(Q9e`bMiwY`Qt>tmoL*kK0H-v@iCU z=QV-0^J*j4`4+^^63W2b{U*94sL8$c{ejQVFMK5x=n(z@?}cxat$cM2NZd-QUfTqn z^mPED)WlQdCh{oDx9~dg3{KM6Ktl&*0(2RVNqx*{>bbY{U>|eHzjKz8Bw6A12XDyw zrTy{!pozFQ&RlNrx(T-GUL%Wmu$DZ1zXlKWc!jSVvSe$f-jN0Pm0@^DYkB0UVEA@6 zhcxM$Pg*Z;h2QM=abZ21VS1va>{-whxcxkZT1B778&@^u=X^WST61mkRJV!v3(0}} zd-gEPs~hkGd%~72M%43yF?o}!1|=KX;iasnbYf&8T*xoNuEP^y(T-8Dy7CFB_%auA z&r+-y_o&oG(~MRO8$mX`><0N~V!*3tAC`_WLh~^?@|%(uq(jqa47t057JmLlaz7Yj zU;_uNVU5elmTPJOHP!YEX>WPQ9aFDzR$%!$8AmLS88gjH71lUi( zP01a_!oEqe!t5CMxOF``y>`X^k`tuzYDd(z44|j=3NZe)A9R?1R(5M-XJ{GdMmx+e zz+g)r#p~~r*8S5l^uuMm?QsgSYRu$;3(Lqpoh*7jc@cebGKKWhpMb=)m&dG+1lBh7 zgW^3u$jF?F5P!cZUhAp}&L>jfc6Ae3-w$gbcthl#h4<4mZVg(DXc9 z9=w_Q?IiedRT}2*--9|s=i#qax47Q({?MV(gV0cTN*?+RkbSqELi9rOXfy5@nnFMP z84^U!>sFHsuJb`-(Ofv(Nf+EgOyR@NK2)>X9ftmxL7k$Cxv(i~a6a!w7fKT#x2>3V zA%p2;ty9E=x5CU|UCb#jrJnt+;Dd@Y__+KtB%Iof8aEOl$tx7HRH~`(ll5?1R*3JW zoWXZzY~YF(l5p>Q2wq--T7LH8!c8@}ILU)-*7`*H?&?C0$JNr)d2!G@Zx4;n7=u}+ znpoy^0yW<1g4eVgM3$X|@2(GlH~MpF?v%^2X5(7Q=O$N?lH+Z0-X~wGXLuUIyL{my z-yJ05=O(xp+-fKLbJYl2s;I*G-n(H?@Cg#uzbk~A{-LK!e&I1L4|05;(%8&1^vNep zwAiti#-!TAjR$UIhS?1?3OqsO+rPl{iTjATa=E1PIdCq>EL4c8Oib_2XhLmD7J}uA zouw0Ico2tZXYNA4d{{Lj945x55Z|T2bkFJ4v|E!X5Y|+bo_{)obe@^ZZCx`L4#j)H z!n3c4YMVzS)Jh8r28@OmS6+IYc~eB?PRr;y$s`y?7Q^auqq#Q5A+j%Hf0fSd9Rc?Y z4-sA?m3y)80A2JY{Qvz}#s9zSyQ#YmL-xkk^iop?{20~-Vv`F&CI1tSS@egT z;*Zdb(Kkp{$BC#v=Pj46aNEM z__6O&Saigi9v<=vw!A9GoQ##+_$oc<9Na-3WSd)Fb@f|{Lz-;1{q!YNw<;Z{#%FUCx$HMm zUALENZXO1{I=SE-z6g%sAebA{AM-5%-uYeN4rk|(0`}|>bG$W<>)Hn%ti3{>R42j@ z!?$$bW;Z-M$rwKC>2Vd=NQ&FELL;B6FhXq$wl8}@YMvg%!4I0l&mIe*Yad@QIX0bM zTW}Q)c)i4*-!9UyPl$ia;^SW@9wQ#7}5BhFXC6q20B|DsOh(x=K$!@xvnud15x83eza|<0vs9phuZ@Xer zg#-qTd@Oq;?+zOVR+5P;6KU%8d2rXM3@4vx3$=?bg69(RV{t_73u0XOitf0ZLgy5>rTR53*x8r5 zIJUPAzBvB~o#UOM-N`&CbX#b-%9$XqAf9KNvK=-MyT?b16G5Z&_KH(aBJjN za;&9+{N$(8u=0Mk%yaY-OkB~Hk}KKd-rj3yHm5nNTW*COHm$`yK39+|y?E#ve-{%5 zSz~0#8d&8dB5OBP6Q{n%kh?2_ME_b@ZR&MguQx;%7&Q~aEl1*)gjA|I(H@>@^`kkn zEMdj27m%PmPTYO%Fk)a}jva?4bIlJ9!u8`E@qB1&{Pb-r70a{;So>SONw&Jf~%12;E01yz@R7@7niZZ4t}oMoyuJgyryZg06=vethyx?|!2>S6>{Me;J0-qd2F_89Lot7t3x>z#f^&_#nxe{T}P1Y~Ja~ zvKFT@>C60$kbLV3U9w$_8&t}xmuK!8e9Bf{)P^*9h1?1pt+sVWd*1^NxV{t}! zDOI^Qjs9HchI&1;#80l;u-|++gYCe(#33yR?PFtz_~S#Ew7WfN)owp`c2+C;@Yxyg zI2;5U?EB$`zFBzJY7vR+(hCFJ)VZPUcHkdX6Lfy^ooY7K0`F-Xq5Fz+ur%fq{b;?O zR^*Ez-CzJ&+xa2)$0&rh&w9pr?A!|-%KS0y{u?N{Wkm~*tYdvOj=o;B5FXgvhohwj z=&9-B>DnSyQedYeS3j}}Zrr$tH^*$p{ROq9jf&VJ#s!o2 z?tlcCNJqcvjlF+}#F4P#^bXSR z$40WiVp-`qM=5M9n9TM``e9&N1lg~V2{x%-82zd}wojWw#~d7vwxc#fsGG$7$%6IR zzwZXf-oVbKpWNqOl-;H=`|g3g?>Nl9mV~X-#!~N*NY@(<22Gh6nfW;s2X3E5RX&!$ zMg3seiv{cHg}vqEW_mfX`=lav^<;aGueHS8ZeD;8K954cRIbfp7qq+HnzN9uz~x*F z+&Q-zwtSldYghmFa2YDan8$mu^I#R|YM_g&t!rRmMQ?1jwYy&YJb z>P0F$j)TdWr=i#IXL#+~W-u-*MiqV+oC-D+4_LYx_f4FPd0if$$<``TUENW3!+$KE z|LO#HGA*&P*(&@nWjxu+S>UY;Q7|~%4(u&|kUxKFaMznVcsBDlJ(lIf`RgA+>~c%y zf1|AA6>UYVdm*mdF%k#I+v15?Vq7fhL4S0yhu@wDWFHG+5EiE4?JnP8`?MuEd!rx5 z&9Ncw>BsQG7nv+(;V4MIpN&^5KbH2AS3{xkELnep)im1UJI&m+3~vTsCF^p3}^NN1wlw{#kn=y0?qGs=%6B=$gQLdW2SfOhDRm zC26;A2t1s57Nj7?uNNl6%Pn2tL%Y@}G8hUD$G5WA5>^ z7N8Yj0&=~lSnH}GZX5oQ9=Go)zqyphu%J_*Caee3?7a}K(GF%G2=MU#IhQUiKL_1* z_k&ZnZuJV47|Xr%Hqwgw+ewN^2n_pbLNC6Oz%tG57#Z-AEE&@s2JR`xi+#$;nyB-1 z<>Vn~ccK%PRDVLz+O@O*GhlVlEAl1dJG|6bLvP5vVCLps@L75j4&PyC)1GHQRhosX zbDN57wQY!nRT7=jLA|JzXUfOd*Y?2sn~1mWKwl555M)jM!eofk|@42 ze6?x~&Fod+Ua5gRLgOK~?b!x;y!|ceJI5B^wLUC6J;MMhM*P76eFVDU_Iz27h|~1F zNhTeawNjRNbq%_8&jas12iTeMPEbGHhW++Y23;)<lZ8yG-FNigS z1s^N4SuPT;3`#;(EkVddhZYx zk24Z$9Ll8gGzX!@S%BQuhpEr+F4%m+1l)EwgskiKgpP||!zIoQCH4lHa7?EoYzq0| zaq;<8N_yx6cla=V&W|Q8L;=YeXay%?)#N+E^zi2NWUT$N9?ZTLW7p2zU}TOS zp8atMvNd$z$C@7`^_w1E$sS0Dncik+S&L!7xy7L3YX@&t0N~L*yp(K&SN0~-p=YvW zvbkSn`mQGE(X315u|u1qXaS4!7oi#FaNc(oY8K$hvq-Y`Lln_8P^0TZhk3Z*mv0i9ry^ zx~8H==_w*sbA=E7-uQH6XSAzn>zK_ z2K*d7knT^S(>k-6b5I#+(^6IbBlSGJv40D+-JgfSJ-gvWj;F863Sr`OQydw|z5^wP zWE-33QM2Pu>G1ib_-cF-@wq*hvxtbLo~AFj(WM5ZADh{V!_>CX?zREkjN@ut!9*9x z_?AfgcTI#TE`Xacp1xl;7mn;LAz@bAaqDm|jB2{P^nGv&=CnUVYx@@U6MdO;GHZr={1^djV7m}F5-g;4!C38Px9VbUfSp9K3ravjUta7xIA_O#JdiI z8D%TTjjhR${pxDz$l~p&vGgr**rhL?p*{4@H%!rlTSB#+DTGaPA!p2!Wk+9H z;Dq@Ou>MsxS{bOzpZ4mDcP$TrVL&Q*-zkCGH7me8S53U$Rvk+1zR}`t!yz;{ovJ#P zlC(voFwX8IzTpZ`a<4Va+uIT!ZZIY$rqy)HlXJ2+S9Nf8jUy(9N6Egk{qA9pTZsL7 z*@|n7)gZjpD_M5QAP^rjr7c1>V~-hif&~$u>4$J**^BUM5XGy@bzZk5yq!O)dfcI3$qM?WihyO+K5A|fkDe#n zz4p3t@2$o-%pCu2N%$=PMhHKMN@eD)*NTvXTLSQy&c|0YQgt~ z8^N;Wdh~f_Nj;qUu-{t<@PqX!odJ1RxvwX>IB3bQy_f^@U3R0&pa4?K&IfK78p*xa zu8@h+gUAMM5Bc%N1s>dY$JADn@tR{UJ#(N5syw+vEG`a(gr<*hPu5x3rCmUVmUn>- z0g1TR@C>H(enwDrKlIgSXIFEM(Z%E>G0LivW&ip`UY7orP1tb(J{OE5JLjcQt-$8g z_TqAUcDJKA<6V3B)xrjRMjqgzzZBuDwzFu5Pg-#1lZZ|P!T%tMixC`qRKI(HLkpn*!7QKE^`s66+F_y2r&*Lpr$*4k^c?|lu&@jH&wa}R3} zWP?$LKT)GzlkO8r$BPCsu;fS^XpReEEQQ0#-;E0JqACIhe5MeqlU|@3vkk1WDw!WH zWyJKZF0|-!K0s_DSoy>gz4p0yeZnH7RUFfBhM;%85L~~yjm$R8#pyF5@MhXEsQI3N zG?&{Sx@4jFNH{+Jv5}5iS^@j?5NoodmIz#kGmDu|$b8!}sQZ)%D}PP~!4_GVqcsCn zy-whb5G%;+l{DLSViVoF)*2cmSX}<(B8HU|leUk4P{G=ooxdfM`doO%ZvOF=W*_sx z`nKQX@oO2rL6a`-zFR~ls29Ud@pK%#wGj*57YNRj#g$_5Q1#~$O`YI~zLPg% z@TV+1Slfe@a>1Y-E&|^}%IOBlt8_H8iGKGEqt_zdqEVR;ti)k(Ke`WZ)_B4sJ1e@l zZ8IG{c@97QSOb#|H-X;B798ZB^V7w5LECf*HmX|h z9lz=f(?79-{D-ezGVknJGMdoKh8RtSLj@)v)xHm0-LtV~0utFvHrR}LP4G3Iro_d8 z(B!X7Nvjh5f#Xp6eH2Yh^QVI2WclT{l`-V`2AsG&4Gv8>$FUB5RCm@)X02{INh($c z6KMhT<~-QWvIBUI%NBVDYQi!LM-;!L({zeeg=I1)ng2hJ^PkT3Urw*wc@YT@|HZCK zm19TE57Xs5M`G0|!np3aMihRfwEB}dnUuo#+2%TG+9|JUms`hPlKzv6kgWLgEf zW2eAMD-9QZb*B+CveDzXA~v_rrKMxe#L)a6O?tDJTz}Ad zy%~e;{)+HJ)DkvqoQfKk&J)q#jbM0G79FEE&@0@!^rBG>iX=neTcImyl}*82))vT2 zECsz2)37vUE=lL|cs2jVXvUUmqP*)E?sB*cJ&9tp>pph}7p$xwJNFn*Z4E=q@FOTV zF%jIN5WdYmN5}PM;8M=-xa<3fbeXIMw+Z*z(BN#=M$d{?O&NtZ(K|`sgiPgSlV_r9K{-wai7`N$yyE_ zOfD;JzkwEp7Ggu@6fjvB4O!=h%;q_B>$k8zI#~@-L2Fz5yf6(q0#qnArEx(dS)6uC zk5_JT2HT$Hz|@u|dU)0=BC${bHildB>FyEMq-!H2r*C2HjD}g(vlt$$#xdPX5OYM* z@$KcsxLhm_!|GJ=tlVOHVnmNOAj8LjO+Dz2OeAC#7Geg2%j063`Pyv^MO4o1FLB?EMg>d&6K-FYI6Z5p6mP)dGx+>X$0zZB?fVP7l<_!chDU})^bJHHyYQLuz>63ZO zgs#xnv!rQNuNv?l$?)}D^1*M#F*;q}50?KcK%vnGSig8JggUK5hf^Xr;C>FaxSod_ zf|+bp!xQTM?lF-V2%%r9ZsDQx)$mp(6ie?W!V5h?c$Adiv?n1R$9w&zI~J;<%v5Xa zy1x}ivW8IoB#`qJCUmD_BilFI4E6Q@!M@abvaas}#?PNko%WQFle5c+-kVG^LpJ~) zTE~L@2D_&7TXb>N_FS4f>mMByn1Pvr-{{33=`_!#m59jhgte=}XzKITjNIjPbnFzP zI)@WU(mN}h@78a!_^veG&N&XZR{EmK1RIzm*wti}cZNM`{-mkr#40wUS{!QqkAvl? zCjMXF^Z#;u)i8hEjn0uj_fXEE&8CK%U zrrVM!cY|2RBm+8GSBJ|x^uy>9H#q2$kMI3W@Imimbbj^(j%#j(dTtH+R+|qe&yQgD z6b;zdu>c>g3nMxs=9rirj?EfH&=tdLn${J=Fr!aMy`()WmrZfB_($Wp)#i}<$N-XT zlZd}Er@gKscz8t+%nPZ(x#ObIzNwTNIGbYIiECtkX)9; zZnbIDbvsXN*Db z%~0B@9$sd%tfta~7Qv*`M zE`xyEea5U%f^XmZmCDPE!?p|xpNFLJ>c(w&He3UZj$R_dEA?o{WmP)PeiN-zwZx`A zf@Ma_aZB3>d|j4L?#?(v(I5tF2A`UJzjgyck845yia?I}zeS^!{K0tRQPgjoiUT&Z z$)JIt@D^i~zAHk-mvh|LJ5zYNg3rnyI!9-T#$oi67qoC78!tJD@eeFpj}Jo{X@z?% zz5L)cnLA}3oY0lwv??!Lm#LsaCdXr1J%eVKL_GBT9nq1Nf@Nkk^pa5zmoJXT2Z?5w_TdT4 zDg#g!7)LiRih-^_0n{gH9B<{D2js~SS^kArmuNtjFpO3Vg4@0lH2zwRxkctA`>!;A zMS3C#aOe7n7DMD~cM-pib6DB(hnRYug55h#gLTy&c=&OUh(^ZJqjqPgmr4xm42WXg zUq?`vh6h-p@Ry3rUI-NeS|lywBRwmXO;wj$@OCJ0I&R|(*7BteK3@bR^Q0c`zqFYg zIU|ad8_t@kL>A(Ov}UXllLc3`WUxyZBbDCzWKZ;2I(qIe5m?q`HtU2I?5DQ$NaYM> z{k&p)Szp5@CXPoXuOxbFc{XJFl~LaVaiBfpJ;ywJBWru4_|L}%h{*2>k{jhttU;myFmB1l%39Bcoss&M^=qMCTucnPJEwD0p9~`fL$sY8H#IwySL3v6H zbVdeI&%3Wtd4DY?oC=306Gt%dz)f;>UMJ{POvhu(V)0$FEkwN7kKabB!0PLGNX!a_ zzsIM8BFF8V`?mviTc@*6&HF&~fh~}do>}j!5Iw$u%gTCwOD|RN_(Vxc-Ei|WFKE%SjPcxyd zbPg=s83ih})6hsP7Z2wO(uEHMVf5c2G=LG>etkRaP)Wl5$MeYMl4v?^z5(fR;B-7K zmhSQGAWps4QOkY=_S*=-;U!T`YSPJO3&kq&){4t0_LT=qV$QHTRfMPv--4=$ARSb` zN==@M(*uM3pp`cTsm)A&bAKQF{O3jcuNKha^J*m9s~*R^-%&ch4tg)HhJv1K@`mH< z?hmKqv5OJRq|r)n`Dq5?b{XXICJSg5TMWNz_hD9S3|uf(1KZU%sL1phe69YK3Z1b5 zMK8pGTV0rcFb#&+iSf<0P%<$2K3!TD1%d|zNVyHyL7?5uh69BLu8+ZF!33N=B^HXl z^uypKG2SNb+@TyZ4_$X3#S2sO7{8)jcvN>E6aDlZ40jqbAvPHpFh_R_+7FM=B?-dt`zHf#@$ad*q$~+A2!azv%CId!37ai@AhY%@ zd7R9^!Ik~!zP5l~5vqsz<3!CoMkAXh&5FS8Gg)xs*gkSRk^xigJiPRE3pLZ)4?8tq zQV*9p#^Qh|zEcaMKg?3-_P^bW>mKZfHkmbGxpE`S@(XP&Zf-%@uer=f>USa%ycotZ z8L%mMO-j~9z@%;0>6?uEnDuxDXuR3T4%pklfKmx5pCybnj|f)UPT)PP;k3e*8j?8r zopiQUf%Wzq?1%eTxQsw56*pCacd=KfZ?!SR-O7bW$MWHw+~1~mB2l#GQzm_?l0pUM zc9WUQJ>mUlFU&JAAlIu7qm^4GX>=;5YXiN>u|_A**Dx{TsV(EWLnO?a-uhyMvj{zO zv571Qy3O`&y-6!SnlS$!E8`2>T>N`6p3`<#V}Gt7ub(H%6JHwzy|1RiYk_{MDOL== zlbXRgWdh$L-X8obk3h9n72TTyOY0U z&NnXxneb3D-B*KmMyHrWYaE3YQ+mkI_yqP>p*Ln1irr8 zgpzN>(1a6LR8~_Q|IrWM4c>>Fm!x=wpF*H`q@wYH))?#XY9+S*orU3ht589?j$T+&I8OV5-1p$sU*F8o$xk3>{}IyS7u=C>Kd zJX9Ho%4y=~JO?XRMx%{wFDbXHCr&~8VX?P9%u-)OTm*m6&E?Iku!0P%?0y51dinUM zE{o_V&m?-I*QwTaYxpK202yoFpn&5})}Tv>eRzry{QbJ2e5Bz$q~5fRleqg^pd7+|{+oL#p-%-wr_%5$mmo?U3R zz!65D=|SSfHFz*M#S{`4 zq|auqc7n|0qo|!53I>Z(=_JQiGCSiA6g{zLm$kh%Wu_$-Kp8yttp7CoyvZ1ipC%`Cli(qnBs!Y<((BIw*|> zTV%mLZ8EI9w};tu^Drt8rons1i^${nRFThqc$wjo&;eB#4$UOu_Y%-5`6+1T@rG+(iTKPyMjS~3D|%0%P4Ll?-EdHxvrsf1R>Wx>4{ z1{6Qdg^mSLaJ(uD2=4(LJUb1SiFMJkTo#UR(Zom%Et=)J1pe_=h`Ep$nf;^=Z&VcE zc(rDfyP&|2c_K$2a4cQaRX@u3EuuGu1c}~)$x!t(6+R5TV46myI9_}YIg~U)9lxD} z2TvFCLpW)vhs$g(pU?HPT$zZ`9b@F6mBD?)3EH&E6B;8XLD4XQub)pq zz}rId{oZwQBT@>i%w@Umj3gL;`4?92yh>$s<@n!dBGWl<8?$mu1J02W3~oI_&3L|K z|HB06oVprEa@ENbx#>JB+d}e6b}s7G29UHuN?cDG;N+(fP`BzS^E6bBe@n0s9IA7m zdhZqV+xm_wR!t`DpZ8EFe;)ik&zSbzi0u7OdcfpOC-^GRXSdJ zrU(L)9f?z126lwi5~n7_R1Agvn@-~O83*v=&`sv|qeZAGKAX!x)N-2GAy`+HMBZ#N zhh0z9+0oG)0^73*pZ2B3!tvk+7>OVc4+_ z<2ILrVt*Mn1&6`m$X`^q`8eqeqg3TXC6zkg1}8Pjv2M;k@>|jr$ET}8#?>{%>l+VR zm#Bm29YGAVdVzmlrC{cqB$##Z7y9HSVsqgRII?XMF4Yo;FK_%v%cE5M7W~yLiWZUP z6r|z*XY}`Lq2VLYQp@KKNV&1y`=o(JDoL`hi)zlz>-^gH!>%fATBs zSe(O}?7sr3r90Te#t)%bWg=0UAP)649-vwm%M9!1Fg3*kM0u|sQS--F49%@O#^<}the?^FC3$(?!oPJwW7ERLuYV`1b7d3i`3P8!QW76&ORx8YmE8aBps7i`KEA!EyqgIQe+hJ?+9A0IbjuMM}ZT{utM|7n88 z-e%~j3&l4aH>aFw*f=S*+D!MsP3#@|!QMYr!`z;jNau2UcFeRM^xSz8LRMad=-vMK zeu){li;d$MSzctk=TCukp(bGZ?+5uB>VlUhN?~j8Ia0dKgqJ*OLbm-5#jejj=v}DI zWuvMw(&;j;FH?v7*H_V5N`ZIPeG)Bu<_2Ar^|<`Y6*BGRelx=tne=4G4RYjqA}nh- zNK2mmrY)C~$j!XnB>v@gj&(1`;@TH%q4r@APq%?^Rf_R3+#Ptu&OBBg2LGQetjY2YlxtWHACesz z@jVVGt=|F463!6EopUt9Z{wHs!i>c92>KtFTH(D1s$c1_&PMgL;O8cS3*5kL_Gfa? z;}o&GC5TF~wxD5X4%RNyp+g{&cIK*M_h}<2WP|ANjCoAT!Um#fwg+l-w!@E-$29uc z+`19h+hom0SI`S7BX_bSsZ>cO?cdDl|7wR}!O#}EHgyLR=Dn4)`5cCsdS{#Zz1OfZ zKLi=Ch@k(~SG`0n(e1}Ay!^?4zUq#EFTH0P&v2UXp)UZfbDF>!0XI(Lj4CCxpUSZ=x~X*g z#*O%WN;|FFAOX4`3n1-j7maEtMT_o1v^RBP)x1}tM4}F_-pT!~*a=Ib za^chX>AV$=VjQED1ow5W5rbR*sNcbtFwj~L+e6171y|F&c4OSZ-i4z>SBbc37s=8M zLW`t*NY(_Rqf8S~bBZIce`lgXS|xNgr{g1WW2TeK9>o2;fo)+E@THFsxh_VDj%W#X zync)|UNNA%wiY@Rc>J%^;@R8>!I0ki2jZL$1W&SM8k#I7`?mW|KyMVdoJ^}_dxsj)5ObX z8M!8qOb=%I!u4SZdSk0MH9C~Wuc!AdHbNgD1yc;il$>j z_N2So2YyN)fm>HiXy&_jjN6D7k*_`ipW1qvORI$GVdpwFoIMKX{~d$X6?x2a#W}3k zM`KaTy8_mH*Mcq2>tT^+Gdv=nscm($ zS<%N^^yR|EsO>R;9h-w7Nk{^Zu1kXSFOh_YajU%h{?6KYbiaQav0C^I zefE3c(&&khGPM93g`=^Q>u2zP@|-g2))4L)j|z5;L}RBgjJqzxyHs$G3`O zNIkj#D2+t!NWiiCb9u*I9B{@BCA>G1MxS}E#SfDjL2O$Bc5bVsRn`KSthN+=w{Tre zbJmiphr3DQmM~iXEgX%Oi10jAXJJ4vh1pl1i`!?Loe- zn|fLtM31%<`m-RD>%{B8#E=-X{55B3+Jpp*PZZ-F%C*5y&CQrA`;^*oJaVMMGstk+ z(`f4`53@}!kna*A`0Vc$vOqnQhBn7Aw`L#{*S`Z3-+TeVhNtx+BpB6drao?-2 zf_r!(S{}nDSCTKBzaZvzjc7KmxrEZ?aOa-z@4jNS#jh267h(FO;UOk9h(4+cQ>%^{$B1^D;lXXEg| z6)?<0{JOJ_`U+2EP2byLI>o39|qFQk1J& zNw%APBl^7>811!)XWJeLj=5`?Z~Q3uFWkEP z1Q}mkf`>ij$&-*0_}64cH`|NzemO?a8Evj$^2q>to+QFm8A>K?ivXLEF?hylBPEi% z2#PU{$OynP*PF0+)f>FsvK_COETRenpHQ9$WbtB#?7TP?zkk^bc@lx(tt!v= zTABcoJ}=*!nP)Kq{5iaB z*1PaVoiJ5Y35NNz7BIV?t$?>G(hU1}4{YgZ#lrD|Fz3K_dPVIZ*}7j5kN!CU^KC0Z zS^pDgk%gp5F%niAw7_h?N}8KlLJyycMc*%1(SA)SP_8RPXEwrA_1Re3v%oANUXu6b zp%kuAktIRfgz#L(e!S4u4H<$yD0Aa7r}sW3zwHEg=hV}nADnTGejje$oQNyuY{#^+ zeRS060}cCl8^q46GxHNQWh&PXvBn;=iKcomdd`o8s+e_fIrtHT71YstiL(LY30`#idB@1&osBmDcv zWf3@{(Oa#S^sTvoVa4+F@~tn_KO!AQehc#!nI({g;klrwc?S0fJ%L9{F4GLB!KO3c zcS3hSJ`Q}42UV{?e7UI^c0~P#m48=|KWjMOyw#WWYTp5NVG{f^`7=mX&Mgp-b*3xT zw?W$AbkKdHj(2hg*je%tyeF0xa0Km1wZ9N9-jR(W(yeeMWt0YNOG3LX*7!*73Erv^ z;obfD78%I}C}uYYr>C4}WS*GAJ#`nbOpc%tr8D`o)tb}l8SJ(w#T(bv!DDX&-Q_{x zKF90lNn7GKrDW!Q%I+rB0~c{iRTbHvI+rNVXl93cUZI-ubntYeKN#{STyl49c4?ssK$UVeB#)hDTX^? zmq#0wr8(44Ru*37e;{Wr5a!OhP+GV(2trhn@%Wl-kat()O-=k^7SEJ|*7GEGwH1q! z8$F1so(Fi1Y4J4DbErgTEQZ$ikn}h!VjMgVa#p-V5z7V`Tb_?CRTrS_c{m8HYM@f% zRbl$s9;U{08a^;N!gMr>!LmpO>!KF1k4vVb)j=T~%3O}Kf9*qwHxu}WxV7?ytu-3y ze}nm78X2GYX}CSf2wgl6K%b>B)HW&Lv^o1A4yN!7Iu|li!*9^t5m(?@v^@FOZv`8M zv(1L)XwaqQC-G23J5@bb3$0BNcv5~C3OO!o$8aFzjGTna-wVjMWIoIix8yXEcI3-P z!Jv9N?HXw#LAz7v#&HDW9>$U#ZIv`B&zj`CUP5KYmy((5jbZXyVQ3#ohKH@!aZzC* z8O?2EHHxJ0a^wJ%LJZBEnF^(fN6=h+1TMr(;doF12&BU7 z)&FRo!XVcpw1Ea4YJ|O`5oBnZFTJwakVr3-IM@2y2${sdt66dGCGDh_iABbP_I!rtxi_U>5MAA9|wG49UmGLw9 zpDYES+jT0@O*>5aL+6MF*Maoz;}s~D(LhB75uT?`2wF@}ML}5)I7h9p>_sSzaasiZ zvd@^HXUAc8**#h^!gUiQs?*bVX0oXsV0bP4|$l z^@J=tvIV+sh=8TjL|)ayc6x2s9^yMb7HuXWQ?{$!bv6Ce|7=;IR68^xJ`j)+qfM)mI|SYXdJu9u9{@6eTPR@Zw4ivcH>&7 zn{-4{8kCjS(7y2@tii${v{mXv*UyiTS@Z^#vvkpO<{bQfPmlAOPtmJ8d~mu;0^GU3 zg5JB43(BkPAyR7sKlR!(rmHEF_5C_V?w;BPs(MMZdBh3wwOD3DU^=<6BouNU&B2MU zr7+`J7m=3v!0z)?qEDMEfNZt|DFYR-+Tf0Pa~<%!sWF&oz5-gr1Gil**eUv#&Y0$h z&nq6Y5)x~1TT&9cqt+Vc`Ut`HE=>#z)B(GfN70C>0+*$?NuK|CqA+eb*_WLKsuD9{ z`3GM}*fs(}ho$L)x^cYYlM>*m{{iZ>+Y4@axuM@40SN8(Bu*Nsn75b9QHa{Yog{U> z$qVkzpmGfL4y?fU_uL_FLk@X#O^WAOn1Y_dB4kU$9det~k}OryaF$OO^DKtXtnbf3 znUn&UpK8mWasNL1BqbUiX5WGt(`sPuj{>GI{|aPWIRUq7OnAQb+1P&a4|MyQW3tYC zTrT~PD~A=}6>xWpT#hk~G-uF4v!0#x^$eVxUJ8$%UL=#oYr>qX%duyJCLVkc0=vvB z@w~Y$T5Y&z#!Yc1Cz~fjcC{0D=qnNZvqE%*=oZ>n!Sxk&iC~I;Chl8(2Le4)aEkjo z;@go;T@2#!ec@^NXI(-K*bSwbB zr8ttG9tX(qaRpGxdq&N7eWU^Wemu9iiqup@VaxLu_2nVSP-`E`mO65s%zm6#IzNW4 z^s*w?jc;Jejj!-dIi4)|D@V39sv~i#r=nS5Xg+fqyF|B`UNGe}ZHw6)k17EIYyqkH zHvsZBGeV6v zD5Be<$#C@SU$Q1vk1-7Pqv zjb9E}`SKq5qNGi)1?XU$@lqNRA^>p--L&WIZ;Y3ypx)#=z4%)Rw=A0s7rCyer~QXv z)x2e>>0e6y?mne?!~>*x6EW*v1U@dQhx>^a$Y=9hyrP>xPPUhjf?AHxTXTczyRO17 zX~odEhsA7z94aa=050b3Y^hZ*8J6lIPR*A&#&JGI&E7!D;uoUB9X@$2X@-kfGkTnz zj&nWV(bHA4;nQzJ>fsv=<(?PuUsVWhv~-0_>y9ISsok?WA!P-)b}OvR?o+rGa9sglLqtIY$Z4*-GHB6bMeO#u8+}m zB5&!sWL#O34xJxQW61)pS7fvS{a0DQ<0<{5)YO6ab&A0qk%h!-ML5O^&H!imc6gN0 z$EX&>0#<~guDJqVM$r%Q)OSL5@j3|2S3|q^Mi{(Q0jfMx;P5JW%wKDSGj|w6Lt8If z^u3vVn<2+bva*F2OIcF1vjBz;4KSt+tyKBVdaS%_No+0+HSlaUVv_M4_TAhi-1FH% zcFc9e$n6iIMzEC;mexk}QD9Zo)gYt(9=rbWCs+`YiZVwQG=1-T4-ReDXr$a$+*9+4 zq)dq<2S#%7>zpPgd&YVYyfcNFcvJ+Hq_0rvbALhkX(q}v34?1E4_E25l9<9F_R#UY zFh2Pgy?!Vd^}RQdaQm;QK2DQN)l}z=1->S4PnQwHtu1K&Q=aQ5e2EzuPaxs>8R8R= zguiM{aQ?@8xOmAhdA75jFdy@A&q4#(wXlN32)$q|jLT^5vTZowmJ)RD@nLtbG{>f~ z^{i1(I2ag;@{i5aK$U?AlG`mzf*5P8%hdzzaa~kp)e^kJd9kgx0i4bHU^vc&mTC*) zkp>$Q;m4r5<2d{FOd%5`u$4J_^%!JF9fXfP&SdUdO)_}=Fx}kZ z3&W#!L}6rrm1)~YjVJiSe}9kh@=|(ke+F~8Af6s^H6~vsOrz&U+-a8GE+%=x6ZX{c za*|rSiQT0=Lc(6B)Bo~!M+!CRNzRx4SL+Tps^#g6k87#FVE{E*UV!h6&X72t40^S+ zkaQXQV!_W_?8e+{^n#cN9G~7s+YZi$?GfT=Ipm4mDsyP=a5nrh<(QyZ+}x5wT=u8w zFuGbq)8yG6s5NpBk3cy5T*7tOMJqtKUp?k~=;C~v_b~okI!x<*LEm}SkiAzIK+fCW z#4#Zb+MZlx`ED%uj|+sJf(S@=m_dizu0YJzA6#aBIqL2T#?k_J;2QTZ#pVe9nzfkD zohHsFh2P;|NmLUjnNY1JX*k&b3gd0I<1Fn{B-(u!9^*LtkFN{iR}0hB>>z}LcjrLB z=qo(?I0DumsDaDw(VPzv4?y*;pR9b+nObTtCxIA_-Ah$=Nr4ehyNk7-%}Go{1+3Vf4_2@5()O*bkYDYNYbvfI zyd1*oI>XF(^Cp5e(KT@*O1tn8D&3prId+-MvyoqW%O_ez?c{u-yY+ z(mpfWjhnFS*Hn^Qd7iY0*g)!86+FDn6Cb*F!HC5WmS3BU(ar|+_z@|%aX*nbhBZ+g zsRf|ba|qvCh(Y1WRhU0p4BxxvlO3E6a`&?k7K$FhD9QCSaC05a{<8v&YOCPTq9h2@ zmnOlP6Zp)FO1P``s;M=t7GpDl*_->9(`Oq5_&mGuJioGGPUF9bWm}ID>CXMQATonK zGsy=jlY3Np?n&%ApG{}pUk*{$@^IHO6{Wd+39e8=T@L~No|g{vkhwoi;^rNwc0ZvC z-fx-a57H%;SDKo{bJM|Qtk;%rs?<>K?elO^Gh2XvGubK7{C2nrOJkaqvMjN=93<{-O zE~)T2-K~EWyT`jA@9<8v(q2T;U#qb%Hpb)aj#hB~`IgaWG(wLVi$J^97{;wQLBI2d z=(zn`V5`w6KDh;u9C(6+2Abfz?mc+!n*b=7U4Z%Dk3dM^SEBx~588~X>D8B#SXfky zlw+pX4F=Omk>>1P$6T;{JAL~K|M>D{`xhUzpVcXo@l=XrWIPa zEw`O9GM&NARXL3Rf}fCS5As0k`Wtd&^D8Kg%qGh|hN97|TqJs4Am?L+`=J0T;*OEB zvUzy?R|Y)VWs7fm7Quy$ELG#bV)NF9p~}`gh$%mX7_ky$^rGRw)0MDSPJ|!QVnnMy z24INyed_sk3NMRXqz03ovCW2BPLd4-za;5( zHy|VVtBC7OoGpzjxqjH|#sd7$_i`9N%%Po5-|3RdgSbpj5=tV{QF?S=5a-3iaB_3I7Tbvl+zad=P0mdu9cXm@P0s38w$`oRf>PDV_AK5SjS9Hpwlao|BY z`YQ54C!w7QG+YX1>4U6C3Ad-r8OK+&DZ@*%E2vy56p~7%*is3{4UsZa2?^4`8=18aHV^fO&;qXkMo$UHzv4M%=P+)gDjMw~R-9lQiLF$2QWX5uVWh}XXa zlui#LdzEfdx7xdOc#$i7S=9o1sfBc3%@Ta=GLI*pIUBe5^wNx`sc>$pvzhVvHhdOv z00+-`QX%sbwB~>z3jPzs#FcXN{%j{ydGQNv2;r@F4#&4Eec<7Xa9nRZgY&4l?4i2? zU+Ja`dQ9kKb7%0$^yw4P%c2Zthu*;!-W2Mb{|deqti|U$BzYIB=fE4j2Rzr~vfJ7^ zJYGu{b+HJCzdNph@OXWiDS8k$rc@B$_h<0X25X=(Nhm(08-6-JXU3kFgJH#Q65is3 z|D7A9+}*&ox0P0oDFhwC0BRsgQkjuuwd0ms_@~y*{o$J z$%g7l*3F; zItKZ}8{ke`KGA)N)OqnMYNX-(Es!Ds zyN|-IZ%3h0YC5SBlwp0GM5)eiU)V7{iQGufVP5pb(j*-}SX+G<-qaj#nrN|!_>T29 zO`RA3r}LMRvw|KpUTisi&mQ>S{;l%zxvc-RG%!u$V@K8*Qs~W+`M2azbU_haRCN~4 z+gsuVrj-PgR^#639bjX+1cXC|z-Epg+5DxDmQE3)VW;fz!Tms#vlnHX#1e_L0)rFZ zKA}TR)!_Q3c9?vP&exLhvt>|3bh$2v;;tQvq z_QJGJQMfI6A(oqHB$LtodyM8jrq z8=a26`*JaD;4!oMb{f=4U52Oy%RqmkE2=ck$FjgkxZRTlR|lNX;)g2i*xSm+?8pE+ z-wG(;JcEA|%!uK3ZZ_YaBj_nIgnr*QLc$ek{t%IZi0*Uf?;{BFUvS-7M<39H`gAx< zRPo&UQ|$P#0@zj63)eZF`Si0j*miP&tgMuSLc2jTt9^RNPD@7KCv}`B@)u4Xdx$Sx zi^$uc=o_X44SmZY?0($s7*Z!caK+q zeSjohKU2&soa6!v*KP*+=3R7t(-u}?FpNqG+mYajr(o&4323Hx2rb`zp%qr~xb9pC z*e||JVmG+Kif;=r$>B7e>vIiub6pH8%5(6I(qX)_aFAGwi}7S+oE-hTbu)q z9c#%Sxnz3X@C8{VF+?7GIF65@2BsbLVCVTXV)D5JI=1f%`LpdVFkJ6W$Mz+3@t-}^ z)L9nG;>@VEunI4<`2s{FTBAm~EV^s0M$3m<(2`nAKZ#D_86VHamkN@orn8kW&+1`0 za~cDErcjxOKj}NQt&n;Helfv^{0+VHH@!{t* zbO&LWrkIZVe|#qLlTVVCqaNh3Img9xixTO;e=HYwg4H^qP~|(0ccXtX^XOG5p8wm4 zx!=a~G(K{@vi1t-K6rq!S(i`!xVfiMk2oJ*@i95KE)skPv#>=?43#^j;pOi_xV(EJ z-rDJid-hB66=HwWZ})Grg6+d(p2KDc{myj(OQ*oWXQ}LtQvr168)3M=b2H`|#M3KV z+cDfzjQ=x98de|BCgsPp@x+`XtmMlIJnbDv_q=z7Vgqlg*HnQ2#OB~zT8xpdeq095 z3*|X>YJxyL42$Nmqc`Vc#HL&tc_Io`hlQxZlul;p<`cO8;se%3h1*N*wvzHGx|mwD z422qJ(ZiP|@r(Ul_NeM0brFn4tM+YpQOE(Zmqj%-^lSy5;6gg;@t4|e&!E=ftMK}IQf?I>&s6#5zc-an)UX^6g!uQm0nI)+It)s=Yd{fY;Bz z!y2yF>znI!a*ylSG;fQhyZ(o+GYzLQZ2z^H3}sePWK2qv$nf0PYEYU;qyd$Rq72Pc z$~==Ygrt;0=E@Z7zE*=m6f!iDPzfm|MY5lL?Em}jV}ID6d|rpOp69-=>-RfPhbiz_ z<^f8}vw5pgd))n~78)&rF!VSdME-WgyhXFv?29KI%Q_90gd$1APd*GkGy<#qcjDGz zD|~2mhUKEXWbcS`K_Ie+n*N%D4~pJk^%8T;%lQr_pER-2L6Q^)C9tzWw%@QTf`ood zr@?knFfG3xe^tBToy{Up{$UJCy;$yGV<}OP&_b1=M{w0in|fAM!r(v}geliirm>am zTl9%+isM7kuL88(zyK5fioh-tKHPPGF-TUHz}YxYl5gdY2CXL`VtW-v#~I+%*O%D6 z^B3^y-&d&>sSYCb>rr^m5SYc>!ncAa33yb{7ug3${qa<|l`w=p*+Jx?!z8#Yyb(hr z%b_Y-1DH+5_)f5=a-N1R&G29_nyID}CUfBA(;xJZryO~7><@unO-wviM|Y$+!lzZM zaPP-I%#GKr@a~sB)za4oQ;!9}^|8;?FoQ#1oc;XYIcV*wv#`I|mpxnPbH%^A zaS7)Pyj5#pwsKc7rIEEy1{Z!XIOGElgO?~ri05eXw<_) zpw5e97@G}1M_dosjy!Uma_Ty?8cMt&ylc&tTCg{2(8U^DsNWz^0TJ_=vypZT6=k2p#aDEYOt!g5h*}T8qJZmyz zBo`v@4U_Nl<*>5)A$cautJoJ;P3niVv9$CeD&OCacUEgs^B3x5$)@$ZxtR>EGG569 z8ZL*J$at79ltBaa#nJi9U3{Hu12)&5z>}_CqVl7ZWF#8Fd_7AnvgxElOK0Mf#Fvi716rwkIPb>Bk@vFn3!dtLoiZ>izmxXILMzQ_)4LEDD2XlAs zQ#i3^BOLL2feF18bVlMW8a!f{k)l+ZM&S98DfFL560@k` zBHiVG16vvz@O7aeXMMjey0driJfSkOkbVY^Oeod=HLef zF{pK&!MXPO27DA@-!_3UuK35;iIbC#$_z`GbcZ;as z)1-&QDA~RDIb~f#Xv2R6K79$HX0M-t;kXNY8fk%7peHY^--$qj8*`0P`7v43mqSwAng@~t$9E}Dk zGGA2*Fob1?a=mG-q&xPUJ&AVt6DZ}_KwMxo6KEBTZ>N2u>PPm2y3u<4HJn#zX)O(6 z{E?u2=MO1W(8A=3y_iH~z$|)_5{&MxW_NW-FgDkpY~ON?41VmziGaKCd8q=33tU0x!`bNj z(h!oHG(T{SYB)2Z1=-{^|djc1!7<7p0^@9~0p9nBCul@HAH>{&3WyOkul zB*D7lC8PzH0(h}Ah#Hnx>Mo0cq0>0MGmgSH1z$Wcxs-?)3xd(tW^T{*2GH;?$3r*8 zz(0o}#|Ja;i*q8B=0($(<2LAWG7tYcZYI$i-on^JC(zqIpRTtTfV1xwVa0s*-VspG z=5h%>*b@NIsttbvec_cwJZ9~l19i1Y@V@#A<;v!f*lJ@Gia!sc?)E4ua0$gG<)WKi zEmL?T3CdkW@q0om8QrD9JbrzNyOQP9&+d@HTTW@%H{^|Oei+~g7m%;S7 zMsR%d5?wAw5X}i42bghQgt|-bwZq|ZCM!iFWR*8zy4Wc&anopcU~C>lma0+ zR)-r~L_xYOgfJd1u(U)6(#9+4quOcIS$+nmds;j`T>h3@w(=*pUPywb+*-)!3uzMB z7RngR3#Va&*Qn83EjrK858Id4(a0a~h~vE=nh?}P4VS&bg4he>boN7LV|p_c(7D8R ze?>u8&H&f0d;(kd+{PW>3(QvOxdXom7q7A1&4un(98Z82tLn*5V$u6Tr5DbbK@dbCnFq$gR9UjI8PasD%%bUzOAMQh-*-UxSBSvi^cMI283OeL4DRdE+LINZwnsY9Ydb)|V2oQB zaT(IbMbK}XAxumpWB9WJ*cjJBo{CMz=jU%i`)~}dlISssex}3`Sg{QjY2Bko4i}=6 zco-24&7#HI1Gq1IE&(0= zJ+Sla?&SKfQcPElhA;jB$a5^gFE*#(r}s~4$vX9f-^^l+I7#Tc`5}0WCV>5@9$o*l zlP-6c1jT^~>{Iun%b)y(=)HHLggheYsW0fy<)TpjN&wCc`NO(bAuuo!i+tRtud;R})|(A}tP9Wd-fEZ^kK2qJ~lC1=bA$I?n~(CB!gYYvQc&}AIr+#40*}B zA(D3&zFgW)m!*`FCn5JBdrSd(s$_@;=24f5Ww@k(M}s~{^HkRa!P=n~ra@~6Ux%cV zW0uBdJ%Q)Zq0|cmmE@=kkKGd=&V%=pS(?}~gx6x& zq_z=;KWP!`wFxY8pC7w}AL5grTHqpNhGPd7puAN&IlP4Jsn$d?f}@2vgA4S_pF&Ux z5kcno9cr`5n=a*_4R?2bFnyRSi<>M7otZU@E;Lf33$IX6wmOSjYXafK!4d3UA;vkv zNgzoz7O;9`D`U1&6x<;MEj4rTm4*ehe5fGjf@Hy}Dhv31Y+x3##3Ml+82RQAE;w)m zYP;VswpYVwXX!$$*wh6B+br?z-ZPk@wiqAD2=o56^6`pInlQ3njaYd;LYGszJRc!l zN{$rZ-dmE`UF1mGmxt0Yt2~%#CQEc2)tQ^(zbf0aL#gXsHh2B;JcK^j0fSeBahA{! zSBxm*xwtmsXSo;8d9pLNCKDVSyNTk4i%iw-TxX1N37#%a0W+Nf(0G#tG2XhIrDwZI z^Q0wM{63Icg^PpI2oFs)dSQSggLnLH5dThl@VL6z?8$$%Iq+8Um=)$0eMJ*mBr)MSw^he$OZ;mrxO* z7P56)4fvNhQuj1DNP-&XprRjs=`153rK=f!tHbcZ`F^E~S|;6*G#jp-TaItUC3ugO zH5k`0Ar#gNg<;msboV_MLXsEbv9=~|{?bEaGv6OrrgI#A3{+#@@i^MubstpT3B%?b zmd81DH`rcG=N3M%L9@34sP=RU2t`fCV;h@^oI@$dsb!E~T@f%#a~82P(owd#8A6#H<~ZnhU*QyBoCJ#W&<`%Z$-T^|}^ zEzZ*%Xn_kE{bZ4a1ITdsxVG|HklXp2OxraCF01-UI-5Dv8(lzdj%*-Tq*8HO?I7~8 z{lTrjGhwpsbMBJ&132~bd?<4NO1x(a5KjDNc;t0~nAZIy8}535YnTl<=NFObLu-IP zU6fp8*gH3dlYbw|E58a~!52|Zuu@J2Eci8@##8{|9nyjWRDXu!c6A_QkfVEAZAFRdn-=A)A|qi2G~Sf4YZdRV>>L;|XQNMTZZ> zL}TbIeK|Uid=bpWQ?Po`S=jJ873X%d9`3!XKc7B=R&5&KV7DEKZ0oTq*x*UvdZa{gc1v1=47 zUz^Y=#cOCy&LtRNx|v0CnN(-$T^MA$oqBD-cyVtpc^mf1?DX6``f2SEoZLaUhFeeJ zz4By8*xk(XZ6itHZe`N)r+_;DY-G7P%b3qEGU(u~ay)8*w0IB2+BrhJ!dMOZ%P5jD zv5Q?XfZ7C!xPov4ZzQWcT4)SR84L~m7M-2@xB zQqx|b9?Q|^%$|oWL+8mzO)lLSv;v+!wn5*jbmqyUZ+JfM7JYwn3g@KdBP@8y=A52| zLG{H>+_*G~N%%)dhJ6tnHa0@W)1RG1%i!&h>3F?tIZE)GqV;46ynl|5d4Fk$>CJtN zQ5hlJ5x&=0t*=DSwN_S)viqkS>$bw{r|~e`<06gEi=h{bStpD%%ad-;z(`|zXbiqj z4o^7%V~i1VZ`VOwz)Yf%nLgwgE1kL>6v*h5_Ru*2i-@jE3sn$!O7tE4!Q)O377FM7 zfBBC8$pv!8uhH$9&9IW2O07D_Y1G0~>2o0;^7wGk zk_#}xIZEOqUyweAgJp&4NX%yvw_s^7zM4dM`*;{Ve+qMNVL3dKP9`%dx%8XYJ@UTp z8kp%Cvi)B#l-)i|ORScXukk$kpp$j6AF?B{oNB@s8VbEGWu%xrQ%&qm#rH>~*|~r) z?D*kCB@4#6t6Uer*uhjhawr#0e2v3eKUpY$Q-}ujLoh$DiZ;oybB8j3=fy&hVV92& zR;Iv^G@sckk4Sv;VLte8RL0OScf8gh3}@G^gI(78u)1XpRc zA9rZyt|+G1HVQw@)*<}uD(EvF$MWMpknIMhXq@+t=ACfBTYcVCB`1?>9xcj!Sn`(+ z>ki-u>vZ|25=W&HKj4>x{GdN$3Tg}9B%&kPuw#J)=hN6qd=^niUtATY(`@zd_NY5q z_;x$JCW0hOb^-0#P==Q`Y{oM025P{!pM=b{#Kf?FD8$aq$d^QfFhv~jT7#-=KPtw% znx?Myr>Q?b;rSD3bXwL+)`58pJKVkCTcHX^E|AS2lBpn?$0Is#%)xT3$ZVcUGmI9@ zf^90lD(^0`1obOWc}Mab)>GE`a8r&WJTeK-S3e~_-vuz;Ee%|amq3Ta9fsyup{nK& z^2IU`Iln|X&LN55WRgc#=>6qN@*?T)Um@@?ZwJ?`-yYjNpRsJt=d2$g1ur+B<2q}W z5Ve1Pyh@jHT&8#i<&UJH{=guaB@qPkj7zAOk1uE{*>Y9o93iStm%F21j@MJfcDqyx zV2y|-C;M+P`bQ~X^yR&nvivD)$jG29J)c*=_Zuo(^x>^iHOoBN#^wi1xhnfp$#Bwo zlxvN`6R-&9Tnpf?di{q^<*p&iY{fWk8u?^{^=-20krp0OS`Nq8+yNayMOx}=MQ2q= za;Geh#`QjXaZ}0!M31J^{p+5Ppg3z>XnFXy ze(MRz+)A$U8GfEpj4o(>{Y=WGR&ph^C82+%0<;xpz~G)K^s1>VY*|@M^F3uqWDTWd zmWzqh%>5{;76++`99%NZ9dYYds&QI~cSm48yz7XFD ztw5PuUqM7s8$7>+aWB_LlQ5}DxTBX1B!~56Slq9ua+2XSd=kU%kZ7X5CJLTR65ny#)C4^fo=sD`m`Uvyji~0(FZ^ zhABoo>ik!cvop?$fXfxM@sQxGin69&5Jft@=TvOc|Hhp<{Spz7GN~;0+)gESSrIv( zXK;6iGbD<0py+S}D2g;;FV2S!^G9(1l`f1LijrNk=VE4j65TY8p7Vml4-^N&3xGu-;0Hw{^y0@bw8{ zKW{4P_?-dwhqAana2SsW{xm!O>of}0wvt_ZE0C7U@=~twacp(@=}!Z9=Z+x;+mv3($JZ!p#Jr1;w+&Et*6=hNkcI0 zRnDLj_suc-^$nD>IgL*TbT{vOp zYUIaUbFcxqJ3Gja7)wlvI|~=TZZ`GwVSzIPX&^7Is^S(2yubiuSVpJlA+^obj4i2mfPw zHYK9lFG0%F5{3hxD&d;zLp-x(3jIDjl`%H}Zny=h9Si-Ypsh zwrs%tXR<+Jt1J`9c#z(~xAcWX5H*Z<;JTf5N7Kq=DE8$NH8ul&FJuzi@ri?F+Lzew zyfKKB*Cxg2H z!y8O>z!UC`X3JQ=Nypbpa*D42yu^Iz;HwmvpzE_*+1U9G9jyIcLm!Pq zkq@0GA@ltW?Bg@$QT?5;QJ@-iq?Ex!u!G6=W*rjF{G@;JcUZ~)j@&kk!cp_>*tqmH zcfDO3%gyw{$JJlx_1Pl40|%Z%)n)}6yGWV7sj`Fl!`(D!wE+13m`|LulS!QHOZxF< zA?R;seW$CxW8wWhu;*SG)-lAH_EHl`!z6{2o`5VyNd8YgD59m4l9dbuo^q$}HU-EGC`=*mO%3%kekN>Qip_~z1o3hye9UQzgexG@tGtha{m|9C6^ z!)0%^W9R*YTgk-rA;@hrt#FGFrn_=wV7iJZvue#S;d2l>XuFP=H(vTE_8?b z|CYk#D}>1%pYdL7o zy_@wnU6_qFewSeKrd+tgXGiaqmk>hLp+UX`ewn*MNy|Y}+ZTkcE$Q^ra68HKdja2a z;@QsNN=n$@8vkBzvTbh(`pPvkr!Oo9jqw21OIr`4`<74v>15cma0l1qB|G0H1{zet=pnmBxZ~kQt{$H>1x?mhy*3!UnjeGG{w(Muf6zkc7UKOFTDbKL zZpjR(6lJ-QU-tZfka~Z3_s9XBwj9G{sS^LqlbVrJu+;cHli96GrF89vDTtzYFC>7vlsE$>tOg{^Ue6BEU zosx{;hRfVQrT@#VN~#-Y3@&-&Kf#r#a^f_Jqq20Fc|6J1y^PzXm%tl0Nf;K62fRy|Q!Qyq#oK3MNp%UDyxfR<`%iJ@%qQc* z?Qigt%^2y*8=^}!)6nUg89f=c!nEE_nPYu%3TN4a*|I%F?q#Y7M_^_&YR=q&KIjRGVs&(i z+9{IiD+BM1Z(*mL2Fq^^#nYy()O6Mzm=OtZi)DLORNTO^{-d}^RG4SKE*#HhJA&2@ zz*Ue)Yn6_Wo~Mm8e$NqX`JIHns~an9PuEkcZ~Zh#;T>sx<^Ep|R4ShZy9bHEsa@+( zacKq)-YCKXc8>RY&M@6@FP?@Mz9m1dYa@#`BjrI#;Ic>o9JVR)9!+|{TwwV$h8H|h z=1?^=*`k8G@TnEriKt=Egd>gr9fqbsX+$q?f=)S~OmkLPpktX5gj$7yqJtei@tn_* zR?!24j{77x^%S%B*$1=s+uIrQ#SCm=88)?IUF3$S6s(YmA(q=^c%LlJfJ1RTG&x9M zyQ)5MgxDtqtKSGm=CD2dGA6_k6z#LqB7xzl~GcN}$U|-if8Y_N)d-6my&e@hh zHgeTolBaJ)*_>Ml{4~|Z+YdbSYP5eqSW9ouBCkqYP@?hypBn z+=gEdJ)zCI8*sBvCQkkF4EjZiQ8FliWlOHb(js9_o8&W+oL+^;gvFtMJOUPs7c!M6 z^NHL-Q>YW+=LxX8oA;Nq>D(9LpqO@->Nm>q!iCR6`MWTLU1PLoSqGhP>ZL=;4!AV8 zhd%GoM^z?|5$$sWft)}rmS*?YmVapL+-8z^@H#mY&H-M+1B`sS53}6Ix%O@igf}>a zGxeA%thjcZ?mlc!{3gB2Lc7jc(<_W+ZkmrUsDkH+OlrrsW0hP2T#~D&z*iwP6RW-A^L6k z02zF@hGC3k=tjGI@LcDGjn67ch~XuS+!jNZe42xswn&0>JaT)l7cus;ZBVj966ro0$`9z=!wUQMS`d0T0}DO=vJM+h7(bOwiA*xRyM|@+=1`RAtHy@kL+Ig{1ghHSsl+lu52grn zc2^d`W+C?8EEtRZ?SGkvsk`WfLoRqo?ErpqzJo*m*u1;Rb+Gk~WsYqvMgi7evRkPU zV^~J&y4W!i`Ik`cp(u3N9}mTuGANlL#^c#ZL)V61q$yE@xNQo9;?H~O6W0_rtA7qw zd>kPwRtIo%t0Z7DHlb_yb!Mx1IPQs(hugUaaSD3}P@6B#6uT>9*VX;-MfWheM^XAT z_YVCoe;dxdV{_~017X)aDZHV211BY3pjFm+(AEK@W}z}3o&T9z`!^P@rkSCG#AAHc ztc9{J@8Lt+EL=6Vg9K^(pu64oSC-^t)2SmT*qv)M$(kpHr+$7l)7M`^bE_(e-HvCJ z>8GWKw};+;%LN(^6r)g;)lD} zG23t(IiQk;Q5tnHd2S;7ai57MyIsLS=`oS^iiiDM)Uj)-2TX*gVb#t`%#vTtTq`Z6 z#t-<&sJs9#RM-aY9NkB<$Kn{#OjqXdwVe3} zrI+{+oZ*8hPJoB4Wx9>1JTE2zMS@D$|%d%i4zFrgNlCA4#t#2&$(-J9`q_GPB7w~a12 z6HYq5*;19B{bbGdyKw%81g=`X9V*qb=z*Jp+;>}VqSl#yxOP~Zwl|rg$zDZza_~C| zIgC_xaWd1v-Z3;wd~j@XCH{4C0+$oQ=x=(IR$S2~;&XHmuexJ&d=Nw-Gag5=du zGTgL?fhDONxP03PZR|2C3qHouj6dF}e#noCd_KzdaQoqHb2YVF8AeZ!hQs3H-!bC$ zC%SEyHgxH~VZJ~6jI!eF%YBz9h59inWA>Y#4X?&u&-^gA+6k5v&f%_4YyjVmaMqE_ zx{m6X;F)>y82#-jBFkAR9A9VhZvUBvSUAdD z(RhMNQr~d#{a!M+#s^=%7-Te=I(YJX5(ei><7=ACxG=E_Se1?(`R8l}n)IvbP>m{AUHw?*QbTwGALJ!t*^T?~ zah(jW_73)Q*5mRk8 zJnJR*gVzC28V|!I)CE7LCXhqPc5q~-Al7II@}{>Bla#={_+{4?`o$)k?)D9=Jlohq zr%&Ri=kwpuF3W=?M$sGvxc$_#}4DYRWlA*liSaoR&+!T=K zZKd}?l69;Ut@q-l{+$Ar%ZxyERU-};{6lSyEv;Z#a$GjM;jI~t2{9d)Ea_rq&*(_qCiTe z8^FPkbvx@vk(Ckeh;hmV@Cxe1k0)YD`=L$TL=T3yeWeknQC|Y%W8>)XL?eAx)WPss zTcGG^QyTVk2lFBF95#qLnmN77NB8dnoVZO7L2%k2X?~;yO+{1DlHFfu)=GlTpbP3c z1yccsCMLT44~^5mg$GSGL1(WEnc*mkyVV7Pe#@%V)*HN9DH_<21>nQ;^woNif7{>1|sM` zsjFf*AsCaQSich82CnZb$=#YV)EpB+*LO}>m2QaZKBvKzdTr30-w9jZ?gN!ZA#}4C zf#X%jK#HfxDU`U0v}72rHa@3EgKXfgqd$sA9#oZ+#aY_X*i+xb9Mv6RI+@pG z#&LGPFdB!aL(40tvKK)eHupJkZV(*oW6<^QObixgoz!vDnZ)`m5Py9a@9t^D3X5A* z-Z%!%@BP5s+W(pP@T-gE(D*?9UPr2*R!O?p?Aa9W6e>~pjV?O;og{y8B41}LhRN%r zXn#Z^Hs)TSPUYtz?c+-v$X)?0uX!l-ScVo?6f&i2=Mu4pos9LIrEp_S12KGQ14S0f zG!K0%|E-t>S=-)|rhq!^yFZ(X2wp%fi>-L>doB}|?S)T0Ldjt5H}Y5UF+B1i^yNVv z7!MsXTVXi`-xPO|pK?dZnC~sD9$JS_oX2qaArI=~;DBEOA8`*}%)~RgRiyX3I7pf{ znLb}TkLGMQfu4C`VEI-WeO|LHx^cFr-z&*0kZ%Q_^NS&7Y9W*lJHyl2@vwY69Mt%# z$v@|nsM{V0n`h7EE}60m-H&Y}PD8Qe{VN4@UeB@$<(6U4*)+^+=i`hFKE*9BlGv`r zDli6dtou6;XCIJ3h4wD26qe$2AM%FH+Bx|3r8AXYUqBZY)X|ID4)8p^4h3AM&?>h# z^iZ-Aud|?^nAZPfbW~J%n^Paa+4^&EGAkU?pN-HR=k<_{3z^AO=uii(kV>g1%c+;T z61dd!|nyywH5YwS$jV1S(3HjHknly;qr%z?5Peye8QXS~TNT}J zwk~}=IN>J-hZf-1;Ugew%Al3ocGNDBfMXNhba0X;MwgdD{T~mK_1h2byA&c{+Acci zdl&@h3=Dh8A^EeyVNg&I2iB&LuhSLKlUGLtM%#$4x=`g)gE~Ar5{Ga2|A>I5ENiP<` zCuZwVzqQ+JpS3h^`C1k1`DDO)+XT6`8}7lkb8?&;@)KBitdQmTjB-~UF^1(|$}z;l z82gwpVtsS$|CyuzXNLa2|Nf^J&ttRk5Fk#3O$C{dHIu35Qg0YuXi3&|pR5StZKNx& zoCL3c12AE?kxrb;p&DFOQnAMe?r7@Jd+KE*>SQ&O7VQp42YeuBZL(QxrvsHZO3fnS z6x{u~oK(9zkfcK^=;nFHpy{tS9NgQ&xIM9@4es+RS8!dSl-EG)pSTl~o*KID%AWtu z-TW72(!pdkQg53Ffo4t2(SNV#4E_tCp|KX5ZhwU|&wM;vHW!*U1p&LPg1SrAK=RQMm8LmM{TQvxM;N==d1i| zs6YLJzC6Z*JHHVc9)5*61zwo`#uis?jl#aWE%@}q9O`;C1zui_fJhx{IJ$Q>_^|Bu zZXJ1gfc0)H_Nbxhz$G&8u)L&+20=x`Ie8<}AN za<&uXbaf`_P7XlJ)&VNl;YsZW%yI6V+oXM88MU}{fIM2S0o^QvQQ+D(-1BuYHis<* z-8UY1Jt3c|dI$K!_b%a%Ws{IlU5sDmhqtH6;*e_v1w1Du@8iv1XEA_=R3MR|U4~ zlkxIIB$htO0IP>XB*jDx4Bpk^ogaRfM!d0RzAD}Dz6ex%DYxQ)2kw+kh4PV=#JbO$ z8>=D7yZZDyogdgvk8$KM>5VWi{$4&6nxp`=mb+ly!FzCU8y9wrPR6k73a~^Z2d3up zsQEe_bUhaXu613k8&;A`>SVp<&H#T`EAjT;o1753fO8<#)W=whl-&Lc@q(5dzdJK^o zH4M4(kfMgma~=r}axUb@bLsm?D=KPL11A&& z@ttWP8Jji?(vRj>ym-_Nql*PNJ4%$mIf`Jl?_^?r^$=V>)qrx6QfPRufdrqbgF3w$ zmh~l%kB=wA%FfN8T}NmP%YnPA?u)J0z7y?nah9)V1ILZ!c*?m);FlH4W;gyxLc8YR zwMoNtm3%lkz5gW@vYZSnBQKHrZSC-NMH(JnnM70LN=WmB0Qc`;IyC8Shv(T0e9v#C zJqGz0$0xCy`Y*LPXu)z8E9rgSDc++b&DE`O+~a{%=n8b%2b18C&LLFdBHbhQ5l^SQr~ z;WX8d6BhMsul_!BA@3<={=KBH)9=vL6BL$c=M!$g0#w>@k?dQ)1Fp&VZOr9sxp-lFiPXq(_Mz63ulghjh zY@B)sUbxIQ!#_i?Uwu7n3|b9W`|_|$+ZHX-OTb~_6yDAh*7G|x2&IYw89TSd;M+Bs zWZwVCReGL6Uo|{r{AN~?8)^yQeaVQHdI-Y>;gw9ETMm9nyF$10t|WKmZkQIn`ALSq zTt*RZQE1v44y$}iF__GxDfL2Nt3L=K*M!M_mZ74RtO>n;qG7%1BWju&jLGZk@qok@ z(!6*B+Ls4VX9+!H6?D+R)(uxV zUN65UlnI|5A5nAZKNzyw57j0mFauROG_AOpWM5eecJ>0KCVLjT-%{iq(@TYp4we%n zQ^}1QIa8CG1teLZ)>-oZ%nAE)^d{Va=1EJ4kruiMSH1|N5^j!v`+9S!0-D5;cc?jt| zJ@W5`7TEpqLQ8TTd|B?I@r$M0nIG;F;iFs8ZOu*$y)cEU)Hg8hC*Hs#wXdv0i{-W_ zTk}lqBgtyDALPj+Mbck?AC|PRXW*Z^O-mn#F-rDg(46^|nt7*UTjXkJURwgGf1fc6 z`qN>wxS1Z>+Q4pkeTj;_9q1YAU=qupcyJ*dyKjd=TVEXM*3H2!nkjT>+aws;s7*eS zLHM59O5dHnhsx7s;Q(1gVmDoZ)CEGEXAhs5ZLrOy?K<0-u-teI_g%*w{xTDaXRU!5 zXCC3fpI2DnuMaG#@W5Q)QLQ7|a3(DsUbkk`&;I@}XnCFlomh_*mCC%uHNLQSdk`-6 zGeVV7Ro<_Rhak}DHul%LAp1EN6Ss5Kexih`I3ucko zCH!>#lxVXq&-ZYn^aywRuWuNvWDTxT?c~$=bX;4S1XrzOvH9&Ayld=Mk2pV@?pjVFH1H18bxbmi!UOw@Q&KiEg_EOr&83;$UIUSfI;t%m}1gIqM z7Pvlm$++i+Vs}eAcrH8y&Y(r3$FuOhUIb(3HHWjQpn!Py{{r6P&-CN#z{<}Pb3lG( zEbZ@c!N>CcNL6a!>82p=-|Jy?^%j3z%3TeFI74t+65JM&01v*Wux6qFW8B3!Mn_$- z{IMAsSn>nAYhTh`M`dtlq90oLbkYff3$!9yn3r_gnY-GHn6vp?j-Qz6)`%rTtlZ+Q106BgcyhDwiS)Jy(`<7pwIb!`>ItvW&N(tgnH@fA>% zss_i0RA8(r4|L771Tjy=nNjXc9{BSC4Y8<%H(REH_grxZ->fJYmMkXzL$z2HVI`QX z&1HyuwxgfAz92#;g_cc4G<|GJ&VStra}vd2jC2y)v$R1h@eQRq__^TqB-ff2$PF+E_H>>f;o=hAHQ-F&;jNn3%2Clp7 z2?nC3XyEOPVMh08(~4d6;qE2qRa^rN-1j~!hx`2E(?E8LD=#|tG*uER#5dojQ;9?g zy!3K5EFT#|Ht)<}MV()=n4Ja630-)NQ{f2PIDUfQ&4P83B`#b%(8E(?V+>q&jO9y_^36Z;;@@_s)# zLtcK8C;Kj6McY!2tC^||-S-GaM85=;1$(H3Oq0qLdh<(SL0hPPrFkVdrKGTK&jmW1X23JJ9fp!? zifnn;L^-!}3+x>dzyXV7Cb^&hZ8Xi{j0X>2JXD5^(J`2^>@+jk=Kw}Z{-h0dEGXGb z0k1!9cp#vFx-H-sEDy86!Tb~?Tz?2NJr9vc&3A-nXh3(rHQ{|6XC>&c_y`x5J%?rO zo}ezbjY#wxpzQM&ytO|9R-C(vhWxFlC$^crk)g<*Hs{>F52ZPVJr6*+5%$cLg#N%d zoR$BDKWeSW^OTMz0aP9QyJF$fXA3e-e~IA2^(xj@)E}c)R-osD&CJyMD3D`lq6W5aVCKE}fI`;j#-$$b|U`Zc3lVh z6nUp(_F?Ze3A`@OlEu68!Do9rTgW&-j?N@5=M&H6;OcN~l^96I`hc}eBE}~jB6kkf z;O&h+XnpBaJazRr+P;`W4a7`Pq~Qbj76rpVY|(B#o}+jgjVyM8DDjzaNiK<2osJ??SLL8ek2-$2R|M4@lWgrKDRT9}byO+3 zL!3RMQ8xYtUDnrNr!JpI*Bw;h^$ett_xmDo$SaAR{Vf@O8*Buf4}H`%{2kFds|AJ{ z^B~_Q7N7iXqSud69OFF$121>r>L;;e{L~E)Uy8J4+;W`OFi5K8*WjDkftYtYmY-{? z#;q{;xY~LSP7st(g%~YRv!4eI7vsQm2M=DCRFGF5r%BJH7bGp1%c%b>K#2+7Abv|5 z_&-C4W88IIb95C`AYMw#jGXcPP%6o1Nbl((AxWswu6N~X!@=}QP+5--bTOhpkBG!Glhd9+<_rEz_>SH?)j{6I* zUvc-iOe6LaDpSXw5+qymHtAo;;}}zgM7pzpzF6){3%UGGSk5OBu2sck{ty%VC#Ncv z)5?-Nc2Huzftj!6O(jw$!?ikB8nfd$nX#7ZqVxSZzSlWCHDrU|4WhtVx{9W#3US=% zYj)yG=0Vo992}IG%zU;?<-EX-*f4G`F3Q`4J^>sfZvPm#keoc59`MBop$xKBe>{&3 z(M6FM?s^JzC4O#BaJ%*?)y_coc9fo9P7%~ewiV=Q_a7Gct(qBSm1EBVnG_5cG!RAbt9nwq{MD=MMR^{xxx+8E8!ux_I@w#|%)_ z>-PMbOJ%5=wE?TGy=c>i1lYVU5gMdYX=c-5$iHBNdSfN<*-#v;T^kRNV^jnM@o!1d zk6I?EdI1P+Q3vl+59`N;WnqZ>A)+_ajB&0|rLzwG<~K_vGI?Kip}&z7rf!qxz4MF1 zgx!b9`>2iJ`^2A`eCh?8D;$@9XB{>FQ3ads*+TL|19bjYf}2Bb(U+tcmh3MkW4*=s zO-;^RpLiBjKL3L|+%n*ntR%YFA`K@W*iZi}-+Xht4Qb7mfYM!cD5&0v6W+R`;PwC* z?C9dhPvTfXzxPtzwE{Bet^ze@r}KQC=m8TLidN%%LA`lD7I7~|)5s;YNc zbCD1tnPAW55XLc|INyoS>TDPd1!A3>%JG6Z&$OK;i0u;Q9a}yDb1rwl@Q4be{GI`4 zt-WEKqAyrH9K$;tX@k3GJA%U+QR?cb%sluw2$&m(Uy6UwsrDt{JF<#-%5@R>(_&#? zP#Nv+;8>$E6LEOa3OaFBFp9p2=NvWupkCh0>Q7DK7$dXbLcvuEd8KGS?Jj;U=!7HM z`v@F)M4xpV13UPLTJ223v1)6Hr<);DqtuLbzfB-wk))u?BoAgArxC?SF>G59MMd5j z;=JUoH2cM6eA-XQe$%^D<4+tWmqg@2dx0cuqr>1eB@r^l zw8F)x7*O>sprH>B6ITxnV!ZD$YqdC+<0zZ66<%48xmpoDc(Hi0#to`HLb3A4OH5MD zg2-3x*de!T}x3-VMwXL0;vsWIZ9x4gus~@6cUS#8i9L~+ngh8{=Jp63r%;-H? zjK+&?xC}@T^&u@xp-TjwI)4g^b_-xizb{>vR6y%AQ`jxb+)14DMNTNj%_EW`*+=h^ z;m+V=dcW5Zj$#dFMC9JgHv{F`p$T5^|S`51iH7IH0j z;wO)DICtPW>0a}euNIn%!TJmNF*X?>aX~=as5m~Y9Kce?0CZinoD8Rla2caFsCe`a z$r&z(XX^^lHvBS_e<=Z*4~FpJ%|$%s@Bznq3P8qY7ZZ55fqp$N3kzMw2#%d|C49G; z=;Go`Tm$|<(y5QEhl?oBg;_+6)yAW%9=A7?q3~unmR(++Qk&zM<`3&Kc663m>`rf`nN#d7@H-pVqlz!1s5|o2RmL zG_;)l?J1@Yd-Y+6$N3#pa*3OGp>5}xL!7r+jrV0Q$2<=nVAm=ofc1a_f2p)56XaY- zwmrOnW(xv=S=dEF4@t2zg7<;Rsy*NmB}Jxp3bNa|tkUrlxbBG@+$`6?wB8E%9K~^i2Qr{1a528h=p!eu zlrjhR+Fd4Eb|~EI7*f43D0`PpgC>K>0N8z#-Oe-+ky@sf!6h zYjDopZ7_CJkD5tr1IvTc1^zvkh||7C+M0M2f)ZC@-Mwyl_G~@K%eaH(+G;pB?V(+% z%yxQn^e2Ab>A>Er`-_(jD?{C6bqr9jC;jh@$-?k^Ag-5(XNK$`J%D3BebYt}p)Tqz zw;WFIO2pF_L+L=819?0kfm8EcA*}jm-TRt%u*Aq3dSBYZ>nTCxrLi#Nt!o60{hEaN zT1Z7a>~L-O444wJgE&4M4|C$TqPev(jPudNR(E?`;^4v-Y>=jTLs>9xZZ}EKYr?rf ztKmqU2MlT^aQ@V#>=8{B@KHEKzTHq0q}1C&UdV33Tyn+$mtv?S6FD!5fYC_K#HEwh z(}4&LklQy8<3C%1M?kr)QN0MNyRC*#qRFIuemOe()Kht@VA#F#7Jjq~MFZOz7~mz% z8+qrB&o{>qvkS+`)6bEVnO;D4&Ypq&UQ=+BqAFR=li*Et2_;MBM9`}X>JcIaXnKyA zz{kEA^TpCw&nzqS+{s`=r7kY`V}LOMvpILICh6thhM+n{dT??+g284YI(-9pPndzT zBr~YsoGg631AG23U&S-&Ny8109c7Be>Am>nwOi8(B2U8 zD04SF@DXE=r;1aLtQcnJy!{Zu{eg|wT^r(#@=H8)a%206dSxw+-|+_W1LyC0zbBV~M;Vvb%_`%uJRMsPp; z2UUE|xqMeIz`7cBs4PB>-i}e&vZsdoeALKwZy8X(r4Ip1JID$xBP`CmOhSJSP>*^R zsxz+AXt#0rkl9KvrV8^8O}575-Iwv^yeQ0kJq4>DG_oK2#?u|~km8cfo7d>5WFh4Gt7?0fVIzPuxJjMS7B{;Wupk}z2HXQaxB(`+oNnfR+`Zy z<_oSKmgl9W3S+b58T#+UI6YaY)(M@VV+IY$ zK3^&Fbi!q@8VG=ABZeTma4qToK9Bn>?lRZD+<;>5XlmCfN`ni1AV}Mj&i}IuZ-j*6 zwE5SV74$k??{*SL%fkq@yJTW{PKaqTu5W^`QEQDV2fbG+MlY6O2_{d2>J#98(SC|rR^WQ^S zt_4Hxjz>7h?O3*Gxpyemu)@>CLydLz9~|ef>fTYel4awI!|8|3eeII=_c5 zIZ=l{RqmtF)$61uC5BZjjfEvG@gQ`jmXrv-q1P&PTzW_fH=Q3M?yC$SV1F`(O8297 zgDRveoPe0fL$FluB&n8iu@~uTw)mocb?=3_3ivPmPe+I}y z-Fk4px(l9J4#Ca_ClCv0B(Z_V!2NPByYiKe!%Xv^&JtsR$O?BeP*gAzF+TMr-Lvl5$f>uwJT*=)8+Z zPsM%~Mii-(NeMg<-w2l9V@T5FJQ77$;;9ljusC#=7>yrc^gFpjzy>4uG#UVI=TC#q zXFGuf8KMsUH$mLT4#b<1vFL$1*iKg<&g8)O;=??zbjGcsj?K+VYdw-_?P0CckL^IZDSj#i51j zEEM~(6nt;#V(i50T%KMZJ8~+(u}TG`s=c9dUL1O8=im;>0dRWYL#NyLqMO`wdi7I3 zaUW|6eR-oOx@CwYv-!kiTR$Tl_87|FzkygTa4se@pHVAZ4HGI~ld3FL2ukLYocIS2 zDBZ$kYLjWo2hN|IyN$>M-=!p>oW$G<=ReR$qUw9w*zpM(JZ;W1Ga+7$x1+3rTR0@9CoTvlu`oVndk1BZghF_Vd0x2=yZx6$T>sBQp% z@#p*ldC!0-#DH$a4xGSs^fkSYQ~Sj0wAMWl7kc=C@a>QEtWF%RTj7LW2Ep|3f;rd& zyRgY&EEa$FhpWlciD>5rJT^WJ>*w7;^;z~{qizD$HXBg#iY-XAT}R2)31B=V&z+eX z?d<{c`jfqt8^5W38>&7ODwyH>!wY6%VtwHek7|BF=J@5uBMSj!(|;&`4Q8 zj|}_Mrp0+MXlah3zeIQ&lKjc^N-L5WTnHZDW)QUu?p>gjye~zLm#7FDDis%)G{W@0JCnz@Ly=tO}pf>_G8{B^@ZaLuPoVGLvE-kkdtPP)s2i zcHE92U??T{EUJ$mzzaeRZ1H~3J?PNN0hxw3)FV0_`;IEXr7xVBc^c_kUYJ?@i<5tRw1pctRvxOjIJ@ChvzTuA6Rk z@+slQTV%0q0@3WA1Dfn&oLtDA?}mR+*X%6x(klX&TuW3E_6PqfQFurx0)FEmCcxVO zZHlGvs!Mdeg7ZmARi)um_-zd2SXv+5SK#%g;}~&%FSEjs2R*8b;G6M84DxiOO7=Z; z-o_%56?PncE>48IOTGi-1k);^SM>gy3#hev7Mc6>DHW9QY3)Rcr+$9o9E6%scDR~Z zGdK=gcUeHMY$et#+Da}wR$+Bcb<<;)y}^5HIoaDDjV4dRQ1`hEgdOvRmaJ6R?_a>g z%sR(R^xQ+*OpMd8TNYosLkCoB6zgMkHq zVRg%MPz_xVuj+%z z(o5U8XP5tu^>mL@F&aKKrwg|)gZX0yQ2)hjoL`m#io??cz2_3)_qdrbar`;Ebrl`J z*`~pC}p8Y+C4h7{Z;FZMp8viYSDZ0$`0=fyz5JrS5~Fab3#9VQ8-AIb2|3>xuLSg@sx zgZBAehmFg#m|^b(+NTwVF9fY<9`OhzgLZ(7^CgVFf0K>)mjJ!a#^CEZga%zkWJ~Wf z{J2^QXUv$5rdbDZQ|B-{P-u*1vZ^r8=PPUVw+kXe-jkG3Pa=N61SZ`x0G+l})b8eZ z{@YmcZlxa?tr4M5kG7!S_ztfB_XbK58$d-;8BQ2nFCV}#<5ppBh z7`p!{@m3AZ#VhhB==}O8_(3U`pE&fKcIllU=9ga6BjVEZ9@IgKVGrGKU?!v+-6O`X z1DrQg59$xr5#38UxFV8cqg=|O2mM=#+QKRfNP9@cmdD_X!XM1(6KdeL9(YGPyHLvK zCYld?A&XC^;^>UWED;vMH^uizo98U9$8Sjg27aI+^XiH8rvc*qOAklgZa{s^2>udX zN+)lPgikIjLGWP*S)r$mm#^=`Y6A%}5;{h3&#i!2afL;LKbPtMir0J5~6U|FZf_p4_*b_9{ zq;B_JaUAU0_z1OI6Ny)r7x-jNfSl%T#(2hC8tG7u0^dD&=~yRo&M1@4p0^82-r0hM zJs-km%%-I;%c$hVvApp(k;%1=!$)&9DV>r*7Un7w%h(~Bv`zt!ENe!!xKNz!Kb|eO zbtR>ok0JX*EbZNWhTgazPGyFL?s@$iUL}6B6ECQPA2-iX&-@ICouz^D`RP=N z$I|%IuSwXevrI`yJRWSMG#6rM=n_-d*m8>|nKTjVJq4>y@HnnJNNxsii3Ya|_P0EJNSuX&@dyN$~UPc9>RF3)S1R$iVf9 z@K3jmJmq>HDvCF896y&#~>*zkrK3W3Dqr^aQd=SaJ!JVDYKBdPp zjIbp=0Vf9a!QsAIkbXG;Ybs@6*P#kJ5M9r)Oh2(H0ly$d=nP)v z$3fPR4A);=0xMP;hmWMy$r9O8T{}5s4e9Rd##tn z3uo?Ni9t0LxLn4R-y86po+bP&S`V8X-OztY9(1~O67R9&@#%sM#Oj3=sgS&c1v~P% zzq5);tmSes+_`e%yu~;(+X}2?7eI&l4R+GE@t|uo6Sq1X0P#mBnDf&Z9OFI)rBd!N z%*!)&BkOj;f{bh6vtb(QH?W{G<}&DuNU<_U7GO+UGK!`5(#jGYcrzRTzLUAR-Xk42 z`l^tfqW1z$?A!v15&CdGL{#uzPY%p22FT*#6;ONlDzwTJpcF3?Ub@@C^KYSqg(JrlK}{E%%#dlnYizjH`)8{K216P9sh`ZCzl%t)^)0p3A|R~#q%Y9&xrAs z-dfM7aPFCEc?Kj-&W8LsrfkV)ZCDa1A*l5~$M{I5!y5G)xN=e?s_k|MXzid^Jyq~S z0hiq>Er6PKYw$f0iIu(m7_7y|_9chG&a<0HAA5tHYT~3_!~sInqVYt-ELdVTmO89$ z#gb=-_?6@H$X&}PB%JjkV$U=Ir?=59SOhl|w-E0-U2K1v048%}fxOeE?`tdYYw}dy zr79I%u#zDIiqfD~kw|(={P_7>d$CgDB)B*uoaz-LBQo32!+#cr^{&F+R8R8KHypjA zrFnJdet>Me0M0#)px{jUg!9Acw(R))Hb_f+Uyh(DAxj;EBxs2 zZ66$cD$4Wb-T?&jCxhfPW!&!FjK8PeWUd)#5>!w2Y8Y}$1C>BmTNx`2FS@=3O z27PBN!Gp!Y%+5G!I zcvb(u^Mbo+BQyEJZ)RfM2}bCO z>m@RHzX=S5O3C)+*|bS!9t~dM2c@qXh+0<@+0>FjgY4JAG^rh6wfHFYRjlSKkE9Z7 zE^D`+>#lBW?`Cv4H`4yKJ0W=aK3IR#m`DwG@UwSLr7sulgQ=e%@E@q2wwv+jyPam& zL8!G4g`0b8nP-1;X@CMVra4GHIy$33J$iI83ovoWP-PWy^ zb9*Lw!i+DJ1Z=6MZ(g@hi{b#dF}BR^^NbTj$n+D_;SdaGP4`3Gkx#6_&}Fu-;uWLa zzaM^O9bmP+ni$K&OPNo*4*a)IWawxtD79Du3I#>vm7FZ(jXuWei=9yPQ-}&u!&qRuWF=C!=@4c=8+bGO;DBr( zcnmASJBx6Q}JC~1&v_&zz@0)_u$f$$Kb$^VkWEAQ^qA3`#s;0>w6O6 zl^hTMn6U7O7NrKB~;g1O=En&H8Khnz35-#Q5EYH!d> ztxJ&~IElBsO`KVEX9KrJ%KwJ1f`a;+K|zK+?BX1+bxb@L7$89sz&)gLgY zb~2=Qa;^)30Dp;iz!VKl43)J&$8XlK?Co=Y(_dM1wveN+Dj#OlL~`%iPjOpBHoWZ* z<(x_P$eMp*@VHL`;~sHdNkIzwn(1I}0Us6)G++_uc6Ajqg4exWw0JKCsUMTc{*q+i z@!~MN=qKdnUctlHI5$e;URdT|3CWpVn7_S;%JYw+X^bU5(LES{v^9`U!w4)l@WO`9 zDdd1n8d!Fw&=6iV8g@;=tBJemVt5IpUyA$~;^V_(PUO-2CpgV*B^elSWW&GYW960e zWX_!`k~t|CrrizZ&IhWflP?KYd)-LOAs#A^i^TG0S^{?dLRi(Z9Bu@;;9#2)UABK4 zT(MyAlH@TuFw{`M3`)zC_&|?j^fGb=ZRydExzBl6PLXH&Xl~p zM@-(%!UvByPm|U+>fv`1v@#>Wu6Z1e#vO-lqo|PJscJX**uwGsb``_jo7|q=V?UZ)slm)QT>wABI8W+z5|weC%+h;I zWZdr3=RUH$xmO;OmCetn2yG=(ZxP&R{fvqpcg3^z3(?as8!K@gZkqK@GK)LfFHGXA)swXd!J_p~-Lu)GM)mX?t*H0c1 zv>CI5@vwTdlDWvuPqR4d5X@Z!O-mx|W~4XZ80S1VZkIz1H?6{LTjvl}!6H;&?ghD4 z5in!)I(2p&57TBUD+tKv zH^}T7NB{cEQ0KlSr20oDnoYPzrhF?WwiBFy*Ed3+WuN8P@jh^=Lk-I$t`WyW9IGKL z5?17|0^)T9`mHMQLowIKnA12j9_#uSW3R0Rwfv)oHPKqF#`L|+|0B7Mg4+lTwcdc6jn$vd>yIN=$L*P)1EIS zIk#5Yy=V|7KYwZAF@6qI``Cifxk>b3WE#zPe@^$))BNNe12|Zg2Vea5FgmX00;TkF zx>@xf6s;1+Uy&A~lZJExZxb3dDG2&u77|w%^3p{0QsNeVjCEl*9D& zSUN6l6UwcK_o5H}C%QUd`Hx={&c8+d> z+oHK}l!i$+5i#@Y$Fea?!qA@W(YLC1)^F4zn{PI>~ZxUqs(o&pUU<`A=7u;C1skx?~|VH%30;ndDT zP*|IVFSh=G+wIjfu%epQa^3%|B?Ztb>`!W_MGw07N{WqLYdY zdA`>Km%XcC9(?x2F3B&%pn6 z7!C*fpsZ~&*>@%oqh_T*&5A>#^ZIWWZF-)+YAM0KNAv-ieYVh5_BXU zBrobCX~vpbT)*Qe_s&<26C=1?boU!rdEh=-zIO?w}L{t8vX|A#B?lYd3RK zDmSADVPthr!^5G8#9HA4%uBmK{4BX%T3Z=>S3k@*X|#Z=&niK6FdlVkEKvLFd6KD+ z#5u1bq3VepZq_&CMMe42hdd41(+~qz+f@V}9oI@Q+Rr;F>hN_ z7P&e)i_F(z*>Ek+Ly$a%r}<5l=>M~1clvb0k|)Z%-Y3nd`KX#>DgPyD<5Ed~@?0#Q ziI^RbLa)3H!s@5~_&6$#{4+7d*i~`hbtMnZdfy~2>BCrkxRt6t1;N#PF7Figm@Tsq z5lpvSNOOkk;o4M1ni4z)Z`vp_Y&4SO^lF-w_mTPH)Cfgp1^j?$J`Fb2MwNUKK}o(b zq)*o-npv;0|BeVx?Rx`V<#H6q-QNhJ6CQ*7Lp5;i(Zb8+));d%5hm(56D2!ifzl2R z41Mhu)q59;9^^3j66k>5zgh5^d={&kf8eLZTj{QvJye&21 z>tkzpapEl!T@<6r4pRS5T4*l(3w5f7Y0spwbeyvU({}khJwBH~AFaJ) zDI|~DMVt{ras)VQ5i&A z)N$*Me`vF)iHLb~4&6u(xb~|QXAKq++nsj^Q(Holjs+3+q6j8>8G>W>Kgc~MB4FLM za9Kw?sYG2c>x{(1E*dC2R8M~J9)ew+1MHl51AZ=COG|3A$+NKvn0LqoQk>*@{Znng z*!K;_&J5skbByYU472v8Gi_17}^#ZGYjs{>! z=ak&Xt|({TEwLltve&Be1kuCH2>3cEi&gXTg_5N$RkZbP(^O!C)}fm?S;f1eLUbCMb)NwI6n(!Wbji zGrtU{oim62`!k@uI|h7`-?JkoR=j;~*6`Dn^Oj7>hVwsvGsl~Rv1VjEHZ?ATE#I2S z7xQY;7du>bURCiB{FxZ?Z@ zrK`P*nu;dmYhp)_@$rZ^2>s2V?q5&JMnNi1=dpiX+7epbk;_~RTmy%@N zi6|y5$=jQkg73Dg@)S(;p)N8B4MUPZ>9Z*fyd4Oqyz<$7y2t29X&sJ?*My$MYH;!5 zQfj~IH4$-6#+iNFv9G}$n^&fgsEBc}a9a=UN#WjG?nZ)3ogC?L{Kb@+nK36tvT(d* zK75RfMWc;%bW=bo$Q)0nK_yA3MMueJ<4d&9+6CRyEif_iGG6mAK!>QOzq7K(-`-7&8&tZWYm@ z*0s27wzXhE+hmM*(ubqdqcGm03ySQ=2mo|wupBY8fjz}C?&PU{P-GN`u z$RS4seit0M^O_X+o6bc0p*osr(~Un~|D=1GRWb1Sch+mzpS{1w2c?3}GuzU+=dQyt zY+1YreMm9xcpCv_o4Ec;#ssqBXBI2}s1go1FD9qEGMRm%eEQx|iEXgVWlc4O&|+yZ zcK2!tZn`N8E-iC{<)(p{pezllGYar;s6WbvwlhOBit*mx`Pjhk$GQ5^p!Cy%Y{!XU z;?YMY3jbn?Z$x17ihCrhd@&B}Iza|JI^f6S80xqs1lO$Ci|Lu0arzY}rgIzTj8(h} z`+xOm40gYS!GO1ao$U}~| ze!8^)&u+WKTr2FN()X0e58Tv*W0lYr}%K-r2?(L>W7ZSY4|iw z7cUGn(1$z52%`Iz!rgz1Aok=j7|x`)oY;Zt6L;E|5CM~lt%-A%8^k~V1@fS}tuK5wHK4)qd0-W4%No=;kglt#$S9nHf6j;K zx1nSvaU>D-K8R!G?q_5(Z6Om&H?YkzeQZpLK5w8oi|mm!!aKY4k<(qkK*@FH6CA{% zG$9Ca;M_-LAz<-c1>c$uGOa7};Le>vQiMl2w(0|>P{;}WN9Lj5Y6i2EwverAfNxhs z5%Kg1FcR5`5A54<;iL=@JnJB2-$BrktDqL9j^LJR${2Eamp7@ah`an3a$oug=Ked3 zAJ#1g`Bk5x{b(s%ShyPF_I6Mq*oMJI)gW`%kSuuo0FQJ1ZLcq)aG`w(o*#dWnf|wt zq%5t)h4;^*_I6?1Bd)?5otglKExQ;O>h<_$;}Y zo^Ja{DpHa$y2XW5z2=;%=C{~C`EiiGdp+4c^ByVX{97~IlJKGa1^h9)gMB-?2(D!) z!M7h{c#j0>5OJlHKGsR(R;vNfxxW@>$2&pOkJnV_u`IQGtqYD)=g5*NC(vut0ixM( zl4#h4!pQcOpm3^%sYqpLskaC!I%JdF-h5EgQ>819=;8|bc{px1XFEF5kIFA**v7{c zp|IK_{QD^%g^nlKz2;a;1&WdwB(jjq+kA!&HrB9R#b*g;=^#q8$MBl&E#WP`=#Q2X zN9oRk)1dy2G%-q!Mk$vGV68Zt_RQmP`^Q#dv%eDBCnduA9uZi)K%93scMVnroWP#m zDAW#4g6f(pWTvbuMrj^^4ZDI+PGhWK;HxYQ1f{|q-G2VWRrfJy=@cpj{CJ(P;4B|bAO?wqJa_*zvwVQDL{nPO5>^JJ^`-x^h zoHs?2S2wP3sP;&VSiE--kPMuE6`8FyJqz;Z)7#?w~k_^e1EY0@s4nqn|Xh{ z`ir#RFr{6iNO#2kA&W*kxom_aPEG_`llqB@r)%I>$vZfHZ8>D*Y4Rkt=hODPNic6n z0k&1-(6kU?=t~xY3-i6fLDwA)d2l=W6kU9mY>4Mfbn)1~RJ#{_UzsW+XV@L_29D^ydWNG9xS}6HE3|6*sZ0?8; z)W)inOuk_XuBHOe)~SWq@iWL-KTnj@_yM(PRm_5wSMaaDG0))adPv^i#7cNwrz_t@ zlTGr?G+S>c-Kd&JdvGD%;WFmUe}93EuN_<;a{`^j%aNnhfbTkM_F2*gav0A(ytH_%-#^+QpGo7yhNhx|hx=fi4ZzM86%)64FZs$?25)-mKF_C5&=<a|ZEd^|YrwuuI6=PB?fXwSjA=c@3vL6My6ae_782g!6Xd)zu- z7#7AX#kN<=Xr-SCv~wL3JIi}`TTz_%z*`m9O@GAk7fhh`*IwM+D8;+c%4L1V3-hYC zl`?Hjr-(@+*EjEx!S|=N;n8hr-V~=+Qo68*z4B==&gjQ-I-8&XqZ(@n8+Il6U8}8f(7dqA|`GUgA*5b~ zj^6?=CdpwJ-ve`p_7hj7uUtnc3Iv|xaLc!;@ToTqb}k76^TbcoX2DoodVdV_)=3Q1 z!ptDBwhM|Qr}6ApsSD1P-lvk=YN+7-F}U%pk((dbkkY%A=+)_pYqEtvT3`e^S*wti zkAZ@gIc%8Z3;uQ=Il+QK5nkQgBBXXRmLP^r$q8q^PHqiclnD8ImbeDMN?`W0Q<$ zudg0Mp-e@g0hOVNW<>wvZ3wimo25|6)7)c4K$AOVh*to10 z6_#El)>E&*1I@?uaykQFgX*|m%3^F7J%cXqb)ZOo8eT8q?#Z&Ebdi-e)VH`0VT)vN zNPPl1>ha`8`ZM;xs3!k>oEgbozn#puk^p-gtHCo(TF@%yS{8hNK75^-L2DgKiInXw z5LVd7nh&3&SvUg>IA&>fk_hwc$VQSdB1{elH;@4?!}`aV1#6)2fMm_s1>-M10GG52 zXdhBR<&RH*wR%A?Iyn^!dfh;sTl=+hvC#k5RaV$A8kTD(z`fKK*ele4J<$w!Tsi~C z-&8SYMUO%0p;D|=P~miKJVwz}j!evqf)~>HP^H{Yy!!jem$qX#ZQlX-a_%}>sxAbk zP!r7EX9*7Ha@jiHpHbgDZSXCqrLC2oxadeRIe2~}{%CW%E6#FE$12o6z8sIY zJfP#x7~rZL?!0inMLu!)Ld5(sXib=o2~a~G4Fu4bs|(@&o}=)`;X2NHJW3C0aV(aC zd&`$UdX57U4zRE`j+Q=~#(EX^u{M3SY`l0kPBC0g3oT-47!fA9c?X$CWqWY{JPXk0 z{erGATXwlX1YU$0!fB44)Ano`B!@i4<5^j#e18YVq?v-Nj{|I1)JNs7&hTQNl)!t{ zKQy)OD}ARk0v`IkrP2q5z^wWgR`kEXSry*YB{BqaA6V11E~}}j#9Ursb6%%EZP z2DoW4jP0Y3ybA*>i1t+qy@5td(18Z_g!}-xyOi@`w8Y@ldrEvmdl%FjEF~qqGQ`h1 zhRkhUL-o7WV483QP8l}DZ*x!K$LdVhd7dO}$T>>7ZcO91+j+o9q#^Eh7DgvicWf^T zq#P%aElx3l`b93BCZPl)CWpfY_hRC?n8K;#Wl*dAlNz^70$ayg)UKTX;)V)>*PpMT z^o4mK9WxQeC1=wP&S#XrDgjot<0><ww%noQo_xtUX#l~Iyn0iOS@mlGoP}OA^BB0tQ}rK zme{Dnj|T4EdNvpN#07!%&K(!dj@bbrw%)XWeDTbBn+UxybP zmXkyV-{#|C>2UnOn=A0rJq2<1l3jla>HWqpjWmncyTsle|+Ib{9F3G}(Mk?EQpgAeb-V^nA|6E=v)-N&3j zUyv?NWx>0o#*%N<<-m?=H1U?Zbf05s~&ymwkaQ_aYz3I4*V`cWr zkDJw!v2@`%I5JBG*Z(~kr?V$W`OP7coUjU?_E=+ePYyBe zu;hE`j=<8z;h241L~v#2WAa?o41|;gR8vG$z(02h&)c~&#KDk=PD%wa!)V%~7zYCj z-!sA2r;$`)JACc(o^YdjY;wPg-naA6nogikG%j%*oD|YC{}y@LVh6Ll-qIJMZFIaO zrv-EggN&d`8oM%pI<)Shj|4qDr&2$7u3-#3B2R2zEkpB5J8+xuJN(wo)tW!0Ft4YnsFcWB1NcB(%&88bywi1fyvG+;p%9_^_i(n%~6 zx-Atic9(;aoC01*NCySeCg#oalXP8XIApLBp!B1NKx6G1=t>Dh_uD*R*?9u*ktD*1 zY{OZNS>&MdE#jfK2qN3m_;nelFl|RPr=i(TU;jGE_34$c=H?nKD7A&E#i=+&GmO|~ zPKAi%Xu9o2ILL8l?;fM8c(IwAHLNg)UGKheXSBG0_-{tb(q8gK&;qe3AMj9)7O8(c zL$L8V$47DOBM;tH!^k*2-W{7C%s74o*!(2}LMg`NY4;dWZ<@|B5p7&n`z+W{kw%Lx zW#EkSVhFhs2xZ0@*lAQvPTd@YrF$lFv-E7Neqf5GoYrafw?_73w=(D)js%R_h_`+_ zqH*HiusQJ&9U2RU58U@q>6Kn~%0Gs7o|ujEcp_lE_bfT|AQ6jRyMu1uJS4lC;ag-I zva9=<%1?GUc)1>KTQ0^L_lsCvI}0;)JLrv_TyDpSUd-R*16$vS0NWA+B3}=IB+rS+ zE}8<;9m;qwegc(keM-)qe@O2PJi;kz+v%^AwXmWl5o3*O$e`IY964NpB96+i_nAH2 zqw)iSVLO`am_jDXML;d5?Mvly?wi(JWUS0QVD~@&fTXV!3}_x^4t?j6E)pnA^Nuif ztKUY-#$rHhO*s>hERPZ%`mn_{n(USNKz>$-5Rrf^2orjR-ZT1{a7dn+MZ4piS z6R(w}C@JFy%S3#V+{xUD=jOCk{)p0rIJm}>#^(#z%@$vn&`l38A-}s!$#Vi~Y>maL zXL+cqKJt_PYS?>NW^6= zNyN}0kMYm?gwgHcAm8SMACB(BY^B*K?Vd?xdMcs1RS)8}{*5KNN_gNB4@LU?aincG zei~N?GT*201(Ve1Ea!<(x$7>Y>)(x0j!9V6@C7yaD{!<^hdrLdXKh$Tc)D!?+s$S0 zj4f^?iSLd0O{aW_UW*6WCtHZwo5e}=xsbA7b!yNz_a0NZBo;zv7NP0ase-Pb6)>?& z9b}ir!x2Si=(=ze<~K&*jk8N}$*b{1+*KPYj>dsnLp;0^i$ixE73woI9!9G7(0M@# zAa`;xc_#W1D;zjq$Z;8}m6r)i+@8}pvq!K({TdAS>ESoY1bF!*A6I3uc<4YkJZ7cA zCV3O~C`;kv%1KyjwEzQ$T`b5Gqr`Dwrt@2{CtaQ9t%KU-5~03+KiKi z-_o8r$tbqekB0OHp{)*=eHHJ5QGRytsp=ERM83g~Taw8`AwUf<#`sVRJYgFK+eZta zKUNe^Ei0l9YNwedp%B#jRCstVgvjKIv%yGAmhv4kdxv<}+k_dG~!W`uzWc8L!Ty7voo&n~-|{7Os}Fgk?1i zq}FW{7=Jj0zMZ|q#_>nl;k+CO4+%uQ_Xeounu{0Td%+RaF07b|u*-1)=UpcF`s;Ll z!Sq_X_1z|TvbhL6Zz-c^!3St{ilfr2$H5_Q0){7VgLOT@#=r`qHGyMv*LE^speu482J%s~*w-Ga)eyqI|2>g@|h&UsMlat?pT+{~Y{9+-m^Q0fyQ@wz;wHL#Y z*OQ?9bvRzAy@OMmC^LQ21yI=&1RXa8G%v*;@*hg_t+M1`{U2>`Txx|yR{D57PF)~j z*vGyc=F^z3E8vmHPol8T27`o-g1Nm4UT?fd#WYtiW_ni%VFK~?At8QL`7NC6d<4CP zRRo#U%9zc_vL*-P!B2lDwva?PETRJk4&@Qv1V8%df){=7Tt`$76tc}J4j8wf6_4E7 zfr}hf$+_WX(y4h0Jok1%QTMAMReer(|9cCrEh-@0;$7s3sx0Zq)n(1aRRy9l(*y$=U+6k+X14pg z67;VAiD5;fP}W(+dpHHazx)o07Mx`T){5wUErrM`l|jImIjl)N1fnH%)XpW1#I!Ly zlcNG^ecYcIxy{0c)MF5t`uH>Yw1b#c-Hd87CK?@3-8IWIG3X;0wzDy1n2cVFjaCHu6?b{ z-x{h0rwhJ91h+4bSp+XdP5`+l;RcF2`x3S@^Uw6TCT1if4EhERDHFX9Y@- zq#2Ij#%ce<4p)+7(KOK6)$@8YhAG?`+0oryx!P-(9wNdkFfpBDIs{_!3Lhk-v?>p$Y|v*uin6X32B; zQ&Rj}8Kt;emFvf}7h;)NK7N#2kF!2+!@8w2L8fPbBsJ~C!@j3*w^=YKj;Z7Cea>jr zXv$@Xh+}#5R`@(wnz;Ab(__W+uwlu5Bo6(sg$}dJYv-a*gdaO$q8h%tQckMm)*-9U z>06}jz%%&*EwV~MBVAek8qQ;4e|IYQ&W-^_HyO;$Ut`aZxL{?e5lSsvMn_eL2+ztH z$Bh%hRW)m|Joh%|n~%cJ>6Q5T<|rw4DuAKahV+!tY8aNw1AiY|6ngNQSo_73+u3fk zMO~YGVL8scs|2J-q+-aFTo9@cz>X(Pup>1S9>fQ-_8w*+Z1kIM^kT{Jl45dH{UqGx zyoMEpJ!HdB0f77-{IeP*hsW@xxhMB zlQ(cH1AO%*1Z&zN&^_Frxw1dr`!Lx(9rK%t zaIk(II>q=vQT9E2daR$=Pc$ND?#F_wu?2MXU&j9JaZnkNO5+NTFrlvl&}V@VO?lgb z`JCU3p8rU1hG~!wAKq|$iXt|zD-zzEpNwMoZzjjho-)An^e~fOz8OEdox`cs_2lc^Q&f8KG;+TDHtb#*1Tp-{ za8Ahw9w+=_M+;Nw;k?|ke{(;6DQUr^{s|Z~Qye1dRbY9AIaoRx@Ol&4snfM<1VgM~ zN-^gP3gee5CPvXOt4+Dw9&gxbrA#jSN06e2>Qpn*73>tepxto|3E!Q~1lSSkkaz+v zHvD33#GbK>4VsxBi+tcxyeBxUH6wObGQ@K4VY;|>KV(SWBYc&^?5w@z)Gg-FfBu2M z*C#<<>m=LoNtYaV|G@ievbvGga@sfxpNP zZpbAv`LQg|XO1OX?8<#3tRzkfKf$U9Ny4oATmR$Z2vRvVeANpYk+lhju6aU}P!4KN z*Td`CXYieE0>_WoO^bT8aLJP*kTh+ihI7?Ol<)+2e(gS~(XJu?c)vvc;XBmm)gb+Q za~4>=bjGxo=kU$7ebhB4oF>~mqU8r2N#>G5`mF9Hy}erzymB4s{kN)Qp5iJz=sHAh zEM%}kr2(aqlrXM12CB}-&^alqiAU}Ms14HQC)OMVk?;cScvnX!mvBBj({cFnh9p`% ziiRy>MdV7sBvh+OfQ538QL5OM`=6^r_sb7-+*uL6zFHElE$GLiOZ+jU&>!8q?t$TX zE}Q3AE0?`g3t}~)yuW@8vd0Q9mW7n%K$^WhrjMq;>eWxdw{sf|#q0-LsVRb}f40D( zTUyZD6h=mhcfpAOlQL)Sp5ezSy)sSDql9NEV{-N~HEhg-dk)rU(%wwYWpbWL$wu(n zc@0X;IcDi?Za>Ul0s5so@KC%khICItX)AxcC>4Ppme!GuQ5$dvM@)IK8HcxMVB}si zur8T{g=P;}U2ZP6F}8)RIj@OP$E(UTPN=i~{T)TF_a3|8R z?|dRn&zphIQ|4i_&TUSUoPp`}aj;pk9`7y_2I-z4+BoZaa1F5+9-2Ql7IbH9s5sBH{?ZcVN!OShHgty&@|crMUm-r#hoiyzU@9p)PKNepYL?V zVMrf=J*H=p@rOo5bi*dhMGR z`onYrxn-pDR1#Fa(1y7guX1ner0h~ zRCS^=xmtStYdOYyWi;wtI|yN|vI4WHXXqNqt#6A29Fpo}{5}t{O39gQHkaYOGB=p& zHkzWj|3Y||p^O7AvVw5;M2z5ku}{BfW5aL;UcZn>|K+hhU!Fl{uK2|s+&d9wc%Fvs zU&Wv~mk*bRlW9Xz3RT-R8&wmhgF^i!>KU_(H4Ik68@sh}E(t=%Cyt~hc#w{ zR>=2Qf{MFB;O^2Jbi$@1c;=BYX3sODT^ILIrFEUSi-vLjP0o)lRzoHQZ^EZm;cWBg zFQ9oJ#exPnX8=1&!kYc#8zO4ljA3<#%D*k}zc5R&v#h zFucqTq}kQntb*&UB!*UCYu!aQ{nTt$DD^3v`mKfQrcR_GYki57A;$yH3WW4!UEu5} z%JBeaf|gA>d*j_@)Cz7VE^Hp$-6{q^Iymr#>-V^q*k=4a&H!808< z@ak>C4}SZeY!Z~fqO)x<`t~l-pP-3vWCBq8;}B@P7X{1pyNRi}G~ZK#Vwd?cI^XCW zBPyOszQ{c$Us~+RzkYh#Y9CF{@ovKLlQVP8|u#FmHd@OZ)n zEP9y)le+_9AU6goR~^KZ((eHp7E?%R1V)Y-?A z$Wa9M>UMHGr4X;n2-D3@Ul@fg(A5m4prPzo# z2fUgthTf-|G0#}P^pwu2u=I8k1m9|)JCT5n>sm6^)rlGn9-~QLG|(r{mmE#}3TvA# zW4!21qp`z`j$I#pk9lLu z(7Gf;c6Em;sd$jhtgW6xTc^eo=EoR&=CL|6de(|e`Yit+r$MUMp5FKw#jc(h#8%si zlA()^WWP-jTR3ZAc}QqIJI6qqZr None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + trimesh_results_path = Path(__file__).resolve().parent / "icp_data.pth" + self.trimesh_results = torch.load(trimesh_results_path) + + @staticmethod + def iterative_closest_point( + batch_size=10, + n_points_X=100, + n_points_Y=100, + dim=3, + use_pointclouds=False, + estimate_scale=False, + ): + + device = torch.device("cuda:0") + + # initialize a ground truth point cloud + X, Y = [ + TestCorrespondingPointsAlignment.init_point_cloud( + batch_size=batch_size, + n_points=n_points, + dim=dim, + device=device, + use_pointclouds=use_pointclouds, + random_pcl_size=True, + fix_seed=i, + ) + for i, n_points in enumerate((n_points_X, n_points_Y)) + ] + + torch.cuda.synchronize() + + def run_iterative_closest_point(): + points_alignment.iterative_closest_point( + X, + Y, + estimate_scale=estimate_scale, + allow_reflection=False, + verbose=False, + max_iterations=100, + relative_rmse_thr=1e-4, + ) + torch.cuda.synchronize() + + return run_iterative_closest_point + + def test_init_transformation(self, batch_size=10): + """ + First runs a full ICP on a random problem. Then takes a given point + in the history of ICP iteration transformations, initializes + a second run of ICP with this transformation and checks whether + both runs ended with the same solution. + """ + + device = torch.device("cuda:0") + + for dim in (2, 3, 11): + for n_points_X in (30, 100): + for n_points_Y in (30, 100): + # initialize ground truth point clouds + X, Y = [ + TestCorrespondingPointsAlignment.init_point_cloud( + batch_size=batch_size, + n_points=n_points, + dim=dim, + device=device, + use_pointclouds=False, + random_pcl_size=True, + ) + for n_points in (n_points_X, n_points_Y) + ] + + # run full icp + converged, _, Xt, ( + R, + T, + s, + ), t_hist = points_alignment.iterative_closest_point( + X, + Y, + estimate_scale=False, + allow_reflection=False, + verbose=False, + max_iterations=100, + ) + + # start from the solution after the third + # iteration of the previous ICP + t_init = t_hist[min(2, len(t_hist) - 1)] + + # rerun the ICP + converged_init, _, Xt_init, ( + R_init, + T_init, + s_init, + ), t_hist_init = points_alignment.iterative_closest_point( + X, + Y, + init_transform=t_init, + estimate_scale=False, + allow_reflection=False, + verbose=False, + max_iterations=100, + ) + + # compare transformations and obtained clouds + # check that both sets of transforms are the same + atol = 3e-5 + self.assertClose(R_init, R, atol=atol) + self.assertClose(T_init, T, atol=atol) + self.assertClose(s_init, s, atol=atol) + self.assertClose(Xt_init, Xt, atol=atol) + + def test_heterogenous_inputs(self, batch_size=10): + """ + Tests whether we get the same result when running ICP on + a set of randomly-sized Pointclouds and on their padded versions. + """ + + device = torch.device("cuda:0") + + for estimate_scale in (True, False): + for max_n_points in (10, 30, 100): + # initialize ground truth point clouds + X_pcl, Y_pcl = [ + TestCorrespondingPointsAlignment.init_point_cloud( + batch_size=batch_size, + n_points=max_n_points, + dim=3, + device=device, + use_pointclouds=True, + random_pcl_size=True, + ) + for _ in range(2) + ] + + # get the padded versions and their num of points + X_padded = X_pcl.points_padded() + Y_padded = Y_pcl.points_padded() + n_points_X = X_pcl.num_points_per_cloud() + n_points_Y = Y_pcl.num_points_per_cloud() + + # run icp with Pointlouds inputs + _, _, Xt_pcl, ( + R_pcl, + T_pcl, + s_pcl, + ), _ = points_alignment.iterative_closest_point( + X_pcl, + Y_pcl, + estimate_scale=estimate_scale, + allow_reflection=False, + verbose=False, + max_iterations=100, + ) + Xt_pcl = Xt_pcl.points_padded() + + # run icp with tensor inputs on each element + # of the batch separately + icp_results = [ + points_alignment.iterative_closest_point( + X_[None, :n_X, :], + Y_[None, :n_Y, :], + estimate_scale=estimate_scale, + allow_reflection=False, + verbose=False, + max_iterations=100, + ) + for X_, Y_, n_X, n_Y in zip( + X_padded, Y_padded, n_points_X, n_points_Y + ) + ] + + # parse out the transformation results + R, T, s = [ + torch.cat([x.RTs[i] for x in icp_results], dim=0) for i in range(3) + ] + + # check that both sets of transforms are the same + atol = 1e-5 + self.assertClose(R_pcl, R, atol=atol) + self.assertClose(T_pcl, T, atol=atol) + self.assertClose(s_pcl, s, atol=atol) + + # compare the transformed point clouds + for pcli in range(batch_size): + nX = n_points_X[pcli] + Xt_ = icp_results[pcli].Xt[0, :nX] + Xt_pcl_ = Xt_pcl[pcli][:nX] + self.assertClose(Xt_pcl_, Xt_, atol=atol) + + def test_compare_with_trimesh(self): + """ + Compares the outputs of `iterative_closest_point` with the results + of `trimesh.registration.icp` from the `trimesh` python package: + https://github.com/mikedh/trimesh + + We have run `trimesh.registration.icp` on several random problems + with different point cloud sizes. The results of trimesh, together with + the randomly generated input clouds are loaded in the constructor of + this class and this test compares the loaded results to our runs. + """ + for n_points_X in (10, 20, 50, 100): + for n_points_Y in (10, 20, 50, 100): + self._compare_with_trimesh(n_points_X=n_points_X, n_points_Y=n_points_Y) + + def _compare_with_trimesh( + self, n_points_X=100, n_points_Y=100, estimate_scale=False + ): + """ + Executes a single test for `iterative_closest_point` for a + specific setting of the inputs / outputs. Compares the result with + the result of the trimesh package on the same input data. + """ + + device = torch.device("cuda:0") + + # load the trimesh results and the initial point clouds for icp + key = (int(n_points_X), int(n_points_Y), int(estimate_scale)) + X, Y, R_trimesh, T_trimesh, s_trimesh = [ + x.to(device) for x in self.trimesh_results[key] + ] + + # run the icp algorithm + converged, _, _, ( + R_ours, + T_ours, + s_ours, + ), _ = points_alignment.iterative_closest_point( + X, + Y, + estimate_scale=estimate_scale, + allow_reflection=False, + verbose=False, + max_iterations=100, + ) + + # check that we have the same transformation + # and that the icp converged + atol = 1e-5 + self.assertClose(R_ours, R_trimesh, atol=atol) + self.assertClose(T_ours, T_trimesh, atol=atol) + self.assertClose(s_ours, s_trimesh, atol=atol) + self.assertTrue(converged) + + class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() @@ -72,10 +321,17 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase): device=None, use_pointclouds=False, random_pcl_size=True, + fix_seed=None, ): """ Generate a batch of normally distributed point clouds. """ + + if fix_seed is not None: + # make sure we always generate the same pointcloud + seed = torch.random.get_rng_state() + torch.manual_seed(fix_seed) + if use_pointclouds: assert dim == 3, "Pointclouds support only 3-dim points." # generate a `batch_size` point clouds with number of points @@ -102,6 +358,10 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase): X = torch.randn( batch_size, n_points, dim, device=device, dtype=torch.float32 ) + + if fix_seed: + torch.random.set_rng_state(seed) + return X @staticmethod @@ -230,7 +490,6 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase): - use_pointclouds ... If True, passes the Pointclouds objects to corresponding_points_alignment. """ - # run this for several different point cloud sizes for n_points in (100, 3, 2, 1): # run this for several different dimensionalities