From cc70950f4064e3feeb55281b829aa55aa4a7e942 Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Thu, 16 Jul 2020 10:15:30 -0700 Subject: [PATCH] barycentric clipping in cuda/c++ Summary: Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting. Added tests and a benchmark to compare with the current implementation in PyTorch - for some cases of large image size/faces per pixel the cuda version is 10x faster. Reviewed By: gkioxari Differential Revision: D21705503 fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769 --- .../csrc/rasterize_meshes/rasterize_meshes.cu | 63 ++++++-- .../csrc/rasterize_meshes/rasterize_meshes.h | 26 +++- .../rasterize_meshes/rasterize_meshes_cpu.cpp | 34 +++-- pytorch3d/csrc/utils/geometry_utils.cuh | 104 ++++++++++++- pytorch3d/csrc/utils/geometry_utils.h | 104 ++++++++++++- pytorch3d/renderer/mesh/rasterize_meshes.py | 45 +++++- pytorch3d/renderer/mesh/rasterizer.py | 4 + pytorch3d/renderer/mesh/renderer.py | 15 -- pytorch3d/renderer/mesh/utils.py | 8 +- tests/bm_barycentric_clipping.py | 112 ++++++++++++++ tests/data/test_blurry_textured_rendering.png | Bin 44166 -> 43705 bytes tests/test_rasterize_meshes.py | 139 +++++++++++++++++- tests/test_render_meshes.py | 12 +- 13 files changed, 611 insertions(+), 55 deletions(-) create mode 100644 tests/bm_barycentric_clipping.py diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index e280b3f7..aa069725 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -114,6 +114,7 @@ __device__ void CheckPixelInsideFace( const float2 pxy, // Coordinates of the pixel const int K, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { const auto v012 = GetSingleFaceVerts(face_verts, face_idx); const float3 v0 = thrust::get<0>(v012); @@ -149,8 +150,12 @@ __device__ void CheckPixelInsideFace( const float3 p_bary = !perspective_correct ? p_bary0 : BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z); + const float3 p_bary_clip = + !clip_barycentric_coords ? p_bary : BarycentricClipForward(p_bary); + + const float pz = + p_bary_clip.x * v0.z + p_bary_clip.y * v1.z + p_bary_clip.z * v2.z; - const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z; if (pz < 0) { return; // Face is behind the image plane. } @@ -158,7 +163,8 @@ __device__ void CheckPixelInsideFace( // Get abs squared distance const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy); - // Use the bary coordinates to determine if the point is inside the face. + // Use the unclipped bary coordinates to determine if the point is inside the + // face. const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f; const float signed_dist = inside ? -dist : dist; @@ -169,7 +175,7 @@ __device__ void CheckPixelInsideFace( if (q_size < K) { // Just insert it. - q[q_size] = {pz, face_idx, signed_dist, p_bary}; + q[q_size] = {pz, face_idx, signed_dist, p_bary_clip}; if (pz > q_max_z) { q_max_z = pz; q_max_idx = q_size; @@ -177,7 +183,7 @@ __device__ void CheckPixelInsideFace( q_size++; } else if (pz < q_max_z) { // Overwrite the old max, and find the new max. - q[q_max_idx] = {pz, face_idx, signed_dist, p_bary}; + q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip}; q_max_z = pz; for (int i = 0; i < K; i++) { if (q[i].z > q_max_z) { @@ -198,6 +204,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( const int64_t* num_faces_per_mesh, const float blur_radius, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces, const int N, const int H, @@ -260,6 +267,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( pxy, K, perspective_correct, + clip_barycentric_coords, cull_backfaces); } @@ -286,6 +294,7 @@ RasterizeMeshesNaiveCuda( const float blur_radius, const int num_closest, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { TORCH_CHECK( face_verts.ndimension() == 3 && face_verts.size(1) == 3 && @@ -343,6 +352,7 @@ RasterizeMeshesNaiveCuda( num_faces_per_mesh.contiguous().data_ptr(), blur_radius, perspective_correct, + clip_barycentric_coords, cull_backfaces, N, H, @@ -365,6 +375,7 @@ __global__ void RasterizeMeshesBackwardCudaKernel( const float* face_verts, // (F, 3, 3) const int64_t* pix_to_face, // (N, H, W, K) const bool perspective_correct, + const bool clip_barycentric_coords, const int N, const int H, const int W, @@ -422,11 +433,15 @@ __global__ void RasterizeMeshesBackwardCudaKernel( const float3 grad_bary_upstream = make_float3( grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2); - const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy); - const float3 bary = !perspective_correct - ? bary0 - : BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z); - const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f; + const float3 b_w = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy); + const float3 b_pp = !perspective_correct + ? b_w + : BarycentricPerspectiveCorrectionForward(b_w, v0.z, v1.z, v2.z); + + const float3 b_w_clip = + !clip_barycentric_coords ? b_pp : BarycentricClipForward(b_pp); + + const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f; const float sign = inside ? -1.0f : 1.0f; // TODO(T52813608) Add support for non-square images. @@ -442,22 +457,29 @@ __global__ void RasterizeMeshesBackwardCudaKernel( // d_zbuf/d_bary_w0 = z0 // d_zbuf/d_bary_w1 = z1 // d_zbuf/d_bary_w2 = z2 - const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z); + const float3 d_zbuf_d_bwclip = make_float3(v0.z, v1.z, v2.z); // Total upstream barycentric gradients are the sum of // external upstream gradients and contribution from zbuf. const float3 grad_bary_f_sum = - (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary); + (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bwclip); + float3 grad_bary0 = grad_bary_f_sum; + + if (clip_barycentric_coords) { + grad_bary0 = BarycentricClipBackward(b_w, grad_bary_f_sum); + } + float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f; if (perspective_correct) { auto perspective_grads = BarycentricPerspectiveCorrectionBackward( - bary0, v0.z, v1.z, v2.z, grad_bary_f_sum); + b_w, v0.z, v1.z, v2.z, grad_bary0); grad_bary0 = thrust::get<0>(perspective_grads); dz0_persp = thrust::get<1>(perspective_grads); dz1_persp = thrust::get<2>(perspective_grads); dz2_persp = thrust::get<3>(perspective_grads); } + auto grad_bary_f = BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0); const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f); @@ -467,15 +489,18 @@ __global__ void RasterizeMeshesBackwardCudaKernel( atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x); atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y); atomicAdd( - grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp); + grad_face_verts + f * 9 + 2, + grad_zbuf_upstream * b_w_clip.x + dz0_persp); atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x); atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y); atomicAdd( - grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp); + grad_face_verts + f * 9 + 5, + grad_zbuf_upstream * b_w_clip.y + dz1_persp); atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x); atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y); atomicAdd( - grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp); + grad_face_verts + f * 9 + 8, + grad_zbuf_upstream * b_w_clip.z + dz2_persp); } } } @@ -486,7 +511,8 @@ at::Tensor RasterizeMeshesBackwardCuda( const at::Tensor& grad_zbuf, // (N, H, W, K) const at::Tensor& grad_bary, // (N, H, W, K, 3) const at::Tensor& grad_dists, // (N, H, W, K) - const bool perspective_correct) { + const bool perspective_correct, + const bool clip_barycentric_coords) { // Check inputs are on the same device at::TensorArg face_verts_t{face_verts, "face_verts", 1}, pix_to_face_t{pix_to_face, "pix_to_face", 2}, @@ -523,6 +549,7 @@ at::Tensor RasterizeMeshesBackwardCuda( face_verts.contiguous().data_ptr(), pix_to_face.contiguous().data_ptr(), perspective_correct, + clip_barycentric_coords, N, H, W, @@ -743,6 +770,7 @@ __global__ void RasterizeMeshesFineCudaKernel( const float blur_radius, const int bin_size, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces, const int N, const int B, @@ -808,6 +836,7 @@ __global__ void RasterizeMeshesFineCudaKernel( pxy, K, perspective_correct, + clip_barycentric_coords, cull_backfaces); } @@ -841,6 +870,7 @@ RasterizeMeshesFineCuda( const int bin_size, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { TORCH_CHECK( face_verts.ndimension() == 3 && face_verts.size(1) == 3 && @@ -890,6 +920,7 @@ RasterizeMeshesFineCuda( blur_radius, bin_size, perspective_correct, + clip_barycentric_coords, cull_backfaces, N, B, diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h index 54031b17..8b9528db 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h @@ -19,6 +19,7 @@ RasterizeMeshesNaiveCpu( const float blur_radius, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces); #ifdef WITH_CUDA @@ -31,6 +32,7 @@ RasterizeMeshesNaiveCuda( const float blur_radius, const int num_closest, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces); #endif // Forward pass for rasterizing a batch of meshes. @@ -92,6 +94,7 @@ RasterizeMeshesNaive( const float blur_radius, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { // TODO: Better type checking. if (face_verts.is_cuda()) { @@ -107,6 +110,7 @@ RasterizeMeshesNaive( blur_radius, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); @@ -120,6 +124,7 @@ RasterizeMeshesNaive( blur_radius, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); } } @@ -134,7 +139,8 @@ torch::Tensor RasterizeMeshesBackwardCpu( const torch::Tensor& grad_bary, const torch::Tensor& grad_zbuf, const torch::Tensor& grad_dists, - const bool perspective_correct); + const bool perspective_correct, + const bool clip_barycentric_coords); #ifdef WITH_CUDA torch::Tensor RasterizeMeshesBackwardCuda( @@ -143,7 +149,8 @@ torch::Tensor RasterizeMeshesBackwardCuda( const torch::Tensor& grad_bary, const torch::Tensor& grad_zbuf, const torch::Tensor& grad_dists, - const bool perspective_correct); + const bool perspective_correct, + const bool clip_barycentric_coords); #endif // Args: @@ -176,7 +183,8 @@ torch::Tensor RasterizeMeshesBackward( const torch::Tensor& grad_zbuf, const torch::Tensor& grad_bary, const torch::Tensor& grad_dists, - const bool perspective_correct) { + const bool perspective_correct, + const bool clip_barycentric_coords) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA CHECK_CUDA(face_verts); @@ -190,7 +198,8 @@ torch::Tensor RasterizeMeshesBackward( grad_zbuf, grad_bary, grad_dists, - perspective_correct); + perspective_correct, + clip_barycentric_coords); #else AT_ERROR("Not compiled with GPU support"); #endif @@ -201,7 +210,8 @@ torch::Tensor RasterizeMeshesBackward( grad_zbuf, grad_bary, grad_dists, - perspective_correct); + perspective_correct, + clip_barycentric_coords); } } @@ -300,6 +310,7 @@ RasterizeMeshesFineCuda( const int bin_size, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces); #endif // Args: @@ -356,6 +367,7 @@ RasterizeMeshesFine( const int bin_size, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA @@ -369,6 +381,7 @@ RasterizeMeshesFine( bin_size, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); #else AT_ERROR("Not compiled with GPU support"); @@ -446,6 +459,7 @@ RasterizeMeshes( const int bin_size, const int max_faces_per_bin, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { if (bin_size > 0 && max_faces_per_bin > 0) { // Use coarse-to-fine rasterization @@ -465,6 +479,7 @@ RasterizeMeshes( bin_size, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); } else { // Use the naive per-pixel implementation @@ -476,6 +491,7 @@ RasterizeMeshes( blur_radius, faces_per_pixel, perspective_correct, + clip_barycentric_coords, cull_backfaces); } } diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp index b01bdbaa..89a7dc3a 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp @@ -108,6 +108,7 @@ RasterizeMeshesNaiveCpu( const float blur_radius, const int faces_per_pixel, const bool perspective_correct, + const bool clip_barycentric_coords, const bool cull_backfaces) { if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || face_verts.size(2) != 3) { @@ -213,8 +214,12 @@ RasterizeMeshesNaiveCpu( ? bary0 : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2); + const vec3 bary_clip = + !clip_barycentric_coords ? bary : BarycentricClipForward(bary); + // Use barycentric coordinates to get the depth of the current pixel - const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2); + const float pz = + (bary_clip.x * z0 + bary_clip.y * z1 + bary_clip.z * z2); if (pz < 0) { continue; // Point is behind the image plane so ignore. @@ -236,7 +241,7 @@ RasterizeMeshesNaiveCpu( continue; } // The current pixel lies inside the current face. - q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z); + q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z); if (static_cast(q.size()) > K) { q.pop(); } @@ -264,7 +269,8 @@ torch::Tensor RasterizeMeshesBackwardCpu( const torch::Tensor& grad_zbuf, // (N, H, W, K) const torch::Tensor& grad_bary, // (N, H, W, K, 3) const torch::Tensor& grad_dists, // (N, H, W, K) - const bool perspective_correct) { + const bool perspective_correct, + const bool clip_barycentric_coords) { const int F = face_verts.size(0); const int N = pix_to_face.size(0); const int H = pix_to_face.size(1); @@ -335,6 +341,8 @@ torch::Tensor RasterizeMeshesBackwardCpu( const vec3 bary = !perspective_correct ? bary0 : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2); + const vec3 bary_clip = + !clip_barycentric_coords ? bary : BarycentricClipForward(bary); // Distances inside the face are negative so get the // correct sign to apply to the upstream gradient. @@ -354,22 +362,28 @@ torch::Tensor RasterizeMeshesBackwardCpu( // d_zbuf/d_bary_w0 = z0 // d_zbuf/d_bary_w1 = z1 // d_zbuf/d_bary_w2 = z2 - const vec3 d_zbuf_d_bary(z0, z1, z2); + const vec3 d_zbuf_d_baryclip(z0, z1, z2); // Total upstream barycentric gradients are the sum of // external upstream gradients and contribution from zbuf. - vec3 grad_bary_f_sum = - (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary); + const vec3 grad_bary_f_sum = + (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_baryclip); vec3 grad_bary0 = grad_bary_f_sum; + + if (clip_barycentric_coords) { + grad_bary0 = BarycentricClipBackward(bary, grad_bary0); + } + if (perspective_correct) { auto perspective_grads = BarycentricPerspectiveCorrectionBackward( - bary0, z0, z1, z2, grad_bary_f_sum); + bary0, z0, z1, z2, grad_bary0); grad_bary0 = std::get<0>(perspective_grads); grad_face_verts[f][0][2] += std::get<1>(perspective_grads); grad_face_verts[f][1][2] += std::get<2>(perspective_grads); grad_face_verts[f][2][2] += std::get<3>(perspective_grads); } + auto grad_bary_f = BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0); const vec2 dbary_d_v0 = std::get<1>(grad_bary_f); @@ -379,13 +393,13 @@ torch::Tensor RasterizeMeshesBackwardCpu( // Update output gradient buffer. grad_face_verts[f][0][0] += dbary_d_v0.x + ddist_d_v0.x; grad_face_verts[f][0][1] += dbary_d_v0.y + ddist_d_v0.y; - grad_face_verts[f][0][2] += grad_zbuf_upstream * bary.x; + grad_face_verts[f][0][2] += grad_zbuf_upstream * bary_clip.x; grad_face_verts[f][1][0] += dbary_d_v1.x + ddist_d_v1.x; grad_face_verts[f][1][1] += dbary_d_v1.y + ddist_d_v1.y; - grad_face_verts[f][1][2] += grad_zbuf_upstream * bary.y; + grad_face_verts[f][1][2] += grad_zbuf_upstream * bary_clip.y; grad_face_verts[f][2][0] += dbary_d_v2.x + ddist_d_v2.x; grad_face_verts[f][2][1] += dbary_d_v2.y + ddist_d_v2.y; - grad_face_verts[f][2][2] += grad_zbuf_upstream * bary.z; + grad_face_verts[f][2][2] += grad_zbuf_upstream * bary_clip.z; } } } diff --git a/pytorch3d/csrc/utils/geometry_utils.cuh b/pytorch3d/csrc/utils/geometry_utils.cuh index 5fd4ed6b..f8d6b2a8 100644 --- a/pytorch3d/csrc/utils/geometry_utils.cuh +++ b/pytorch3d/csrc/utils/geometry_utils.cuh @@ -221,8 +221,108 @@ BarycentricPerspectiveCorrectionBackward( return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2); } -// Calculate minimum squared distance between a line segment (v1 - v0) and a -// point p. +// Clip negative barycentric coordinates to 0.0 and renormalize so +// the barycentric coordinates for a point sum to 1. When the blur_radius +// is greater than 0, a face will still be recorded as overlapping a pixel +// if the pixel is outisde the face. In this case at least one of the +// barycentric coordinates for the pixel relative to the face will be negative. +// Clipping will ensure that the texture and z buffer are interpolated +// correctly. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can be outside the +// range [0, 1]. +// +// Returns +// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which +// satisfy the condition: sum(w0, w1, w2) = 1.0. +// +__device__ inline float3 BarycentricClipForward(const float3 bary) { + float3 w = make_float3(0.0f, 0.0f, 0.0f); + // Clamp lower bound only + w.x = max(bary.x, 0.0); + w.y = max(bary.y, 0.0); + w.z = max(bary.z, 0.0); + float w_sum = w.x + w.y + w.z; + w_sum = fmaxf(w_sum, 1e-5); + w.x /= w_sum; + w.y /= w_sum; + w.z /= w_sum; + + return w; +} + +// Backward pass for barycentric coordinate clipping. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can be outside the +// range [0, 1]. +// grad_baryclip_upstream: vec3 Upstream gradient for each of the clipped +// barycentric coordinates [grad_w0, grad_w1, grad_w2]. +// +// Returns +// vec3 of gradients for the unclipped barycentric coordinates: +// (grad_w0, grad_w1, grad_w2) +// +__device__ inline float3 BarycentricClipBackward( + const float3 bary, + const float3 grad_baryclip_upstream) { + // Redo some of the forward pass calculations + float3 w = make_float3(0.0f, 0.0f, 0.0f); + // Clamp lower bound only + w.x = max(bary.x, 0.0); + w.y = max(bary.y, 0.0); + w.z = max(bary.z, 0.0); + float w_sum = w.x + w.y + w.z; + + float3 grad_bary = make_float3(1.0f, 1.0f, 1.0f); + float3 grad_clip = make_float3(1.0f, 1.0f, 1.0f); + float3 grad_sum = make_float3(1.0f, 1.0f, 1.0f); + + // Check if sum was clipped. + float grad_sum_clip = 1.0f; + if (w_sum < 1e-5) { + grad_sum_clip = 0.0f; + w_sum = 1e-5; + } + + // Check if any of bary values have been clipped. + if (bary.x < 0.0f) { + grad_clip.x = 0.0f; + } + if (bary.y < 0.0f) { + grad_clip.y = 0.0f; + } + if (bary.z < 0.0f) { + grad_clip.z = 0.0f; + } + + // Gradients of the sum. + grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip; + + // Gradients for each of the bary coordinates including the cross terms + // from the sum. + grad_bary.x = grad_clip.x * + (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.y = grad_clip.y * + (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.z = grad_clip.z * + (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y)); + + return grad_bary; +} + +// Return minimum distance between line segment (v1 - v0) and point p. // // Args: // p: Coordinates of a point. diff --git a/pytorch3d/csrc/utils/geometry_utils.h b/pytorch3d/csrc/utils/geometry_utils.h index e283c8f4..f24f58ac 100644 --- a/pytorch3d/csrc/utils/geometry_utils.h +++ b/pytorch3d/csrc/utils/geometry_utils.h @@ -242,8 +242,108 @@ inline std::tuple, T, T, T> BarycentricPerspectiveCorrectionBackward( return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2); } -// Calculate minimum squared distance between a line segment (v1 - v0) and a -// point p. +// Clip negative barycentric coordinates to 0.0 and renormalize so +// the barycentric coordinates for a point sum to 1. When the blur_radius +// is greater than 0, a face will still be recorded as overlapping a pixel +// if the pixel is outisde the face. In this case at least one of the +// barycentric coordinates for the pixel relative to the face will be negative. +// Clipping will ensure that the texture and z buffer are interpolated +// correctly. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0. +// +// Returns +// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which +// satisfy the condition: sum(w0, w1, w2) = 1.0. +// +template +vec3 BarycentricClipForward(const vec3 bary) { + vec3 w(0.0f, 0.0f, 0.0f); + // Only clamp negative values to 0.0. + // No need to clamp values > 1.0 as they will be renormalized. + w.x = std::max(bary.x, 0.0f); + w.y = std::max(bary.y, 0.0f); + w.z = std::max(bary.z, 0.0f); + float w_sum = w.x + w.y + w.z; + w_sum = std::fmaxf(w_sum, 1e-5); + w.x /= w_sum; + w.y /= w_sum; + w.z /= w_sum; + return w; +} + +// Backward pass for barycentric coordinate clipping. +// +// Args +// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0. +// grad_baryclip_upstream: vec3 Upstream gradient for each of the clipped +// barycentric coordinates [grad_w0, grad_w1, grad_w2]. +// +// Returns +// vec3 of gradients for the unclipped barycentric coordinates: +// (grad_w0, grad_w1, grad_w2) +// +template +vec3 BarycentricClipBackward( + const vec3 bary, + const vec3 grad_baryclip_upstream) { + // Redo some of the forward pass calculations + vec3 w(0.0f, 0.0f, 0.0f); + w.x = std::max(bary.x, 0.0f); + w.y = std::max(bary.y, 0.0f); + w.z = std::max(bary.z, 0.0f); + float w_sum = w.x + w.y + w.z; + + vec3 grad_bary(1.0f, 1.0f, 1.0f); + vec3 grad_clip(1.0f, 1.0f, 1.0f); + vec3 grad_sum(1.0f, 1.0f, 1.0f); + + // Check if the sum was clipped. + float grad_sum_clip = 1.0f; + if (w_sum < 1e-5) { + grad_sum_clip = 0.0f; + w_sum = 1e-5; + } + + // Check if any of the bary coordinates have been clipped. + // Only negative values are clamped to 0.0. + if (bary.x < 0.0f) { + grad_clip.x = 0.0f; + } + if (bary.y < 0.0f) { + grad_clip.y = 0.0f; + } + if (bary.z < 0.0f) { + grad_clip.z = 0.0f; + } + + // Gradients of the sum. + grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip; + grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip; + + // Gradients for each of the bary coordinates including the cross terms + // from the sum. + grad_bary.x = grad_clip.x * + (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.y = grad_clip.y * + (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.z * (grad_sum.z)); + + grad_bary.z = grad_clip.z * + (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) + + grad_baryclip_upstream.x * (grad_sum.x) + + grad_baryclip_upstream.y * (grad_sum.y)); + + return grad_bary; +} + +// Calculate minimum distance between a line segment (v1 - v0) and point p. // // Args: // p: Coordinates of a point. diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index 125eec20..984d3352 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -24,6 +24,7 @@ def rasterize_meshes( bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): """ @@ -143,6 +144,7 @@ def rasterize_meshes( bin_size, max_faces_per_bin, perspective_correct, + clip_barycentric_coords, cull_backfaces, ) @@ -183,6 +185,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): bin_size: int = 0, max_faces_per_bin: int = 0, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`. @@ -196,11 +199,13 @@ class _RasterizeFaceVerts(torch.autograd.Function): bin_size, max_faces_per_bin, perspective_correct, + clip_barycentric_coords, cull_backfaces, ) ctx.save_for_backward(face_verts, pix_to_face) ctx.mark_non_differentiable(pix_to_face) ctx.perspective_correct = perspective_correct + ctx.clip_barycentric_coords = clip_barycentric_coords return pix_to_face, zbuf, barycentric_coords, dists @staticmethod @@ -214,6 +219,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_bin_size = None grad_max_faces_per_bin = None grad_perspective_correct = None + grad_clip_barycentric_coords = None grad_cull_backfaces = None face_verts, pix_to_face = ctx.saved_tensors grad_face_verts = _C.rasterize_meshes_backward( @@ -223,6 +229,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_barycentric_coords, grad_dists, ctx.perspective_correct, + ctx.clip_barycentric_coords, ) grads = ( grad_face_verts, @@ -234,6 +241,7 @@ class _RasterizeFaceVerts(torch.autograd.Function): grad_bin_size, grad_max_faces_per_bin, grad_perspective_correct, + grad_clip_barycentric_coords, grad_cull_backfaces, ) return grads @@ -250,6 +258,7 @@ def rasterize_meshes_python( blur_radius: float = 0.0, faces_per_pixel: int = 8, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): """ @@ -356,6 +365,14 @@ def rasterize_meshes_python( top2 = z0 * z1 * l2 bot = top0 + top1 + top2 bary = torch.stack([top0 / bot, top1 / bot, top2 / bot]) + + # Check if inside before clipping + inside = all(x > 0.0 for x in bary) + + # Barycentric clipping + if clip_barycentric_coords: + bary = barycentric_coordinates_clip(bary) + # use clipped barycentric coords to calculate the z value pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2] # Check if point is behind the image. @@ -365,7 +382,6 @@ def rasterize_meshes_python( # Calculate signed 2D distance from point to face. # Points inside the triangle have negative distance. dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2]) - inside = all(x > 0.0 for x in bary) signed_dist = dist * -1.0 if inside else dist @@ -433,6 +449,33 @@ def edge_function(p, v0, v1): return (p[0] - v0[0]) * (v1[1] - v0[1]) - (p[1] - v0[1]) * (v1[0] - v0[0]) +def barycentric_coordinates_clip(bary): + """ + Clip negative barycentric coordinates to 0.0 and renormalize so + the barycentric coordinates for a point sum to 1. When the blur_radius + is greater than 0, a face will still be recorded as overlapping a pixel + if the pixel is outisde the face. In this case at least one of the + barycentric coordinates for the pixel relative to the face will be negative. + Clipping will ensure that the texture and z buffer are interpolated correctly. + + Args: + bary: tuple of barycentric coordinates + + Returns + bary_clip: (w0, w1, w2) barycentric coordinates with no negative values. + """ + # Only negative values are clamped to 0.0. + w0_clip = torch.clamp(bary[0], min=0.0) + w1_clip = torch.clamp(bary[1], min=0.0) + w2_clip = torch.clamp(bary[2], min=0.0) + bary_sum = torch.clamp(w0_clip + w1_clip + w2_clip, min=1e-5) + w0_clip = w0_clip / bary_sum + w1_clip = w1_clip / bary_sum + w2_clip = w2_clip / bary_sum + + return (w0_clip, w1_clip, w2_clip) + + def barycentric_coordinates(p, v0, v1, v2): """ Compute the barycentric coordinates of a point relative to a triangle. diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index 3ea2ae16..eefe1cf4 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -26,6 +26,7 @@ class RasterizationSettings: "bin_size", "max_faces_per_bin", "perspective_correct", + "clip_barycentric_coords", "cull_backfaces", ] @@ -37,6 +38,7 @@ class RasterizationSettings: bin_size: Optional[int] = None, max_faces_per_bin: Optional[int] = None, perspective_correct: bool = False, + clip_barycentric_coords: bool = False, cull_backfaces: bool = False, ): self.image_size = image_size @@ -45,6 +47,7 @@ class RasterizationSettings: self.bin_size = bin_size self.max_faces_per_bin = max_faces_per_bin self.perspective_correct = perspective_correct + self.clip_barycentric_coords = clip_barycentric_coords self.cull_backfaces = cull_backfaces @@ -127,6 +130,7 @@ class MeshRasterizer(nn.Module): bin_size=raster_settings.bin_size, max_faces_per_bin=raster_settings.max_faces_per_bin, perspective_correct=raster_settings.perspective_correct, + clip_barycentric_coords=raster_settings.clip_barycentric_coords, cull_backfaces=raster_settings.cull_backfaces, ) return Fragments( diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py index 87835a70..46c0231f 100644 --- a/pytorch3d/renderer/mesh/renderer.py +++ b/pytorch3d/renderer/mesh/renderer.py @@ -49,21 +49,6 @@ class MeshRenderer(nn.Module): the range for the corresponding face. """ fragments = self.rasterizer(meshes_world, **kwargs) - raster_settings = kwargs.get("raster_settings", self.rasterizer.raster_settings) - if raster_settings.blur_radius > 0.0: - # TODO: potentially move barycentric clipping to the rasterizer - # if no downstream functions requires unclipped values. - # This will avoid unnecssary re-interpolation of the z buffer. - clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords) - clipped_zbuf = _interpolate_zbuf( - fragments.pix_to_face, clipped_bary_coords, meshes_world - ) - fragments = Fragments( - bary_coords=clipped_bary_coords, - zbuf=clipped_zbuf, - dists=fragments.dists, - pix_to_face=fragments.pix_to_face, - ) images = self.shader(fragments, meshes_world, **kwargs) return images diff --git a/pytorch3d/renderer/mesh/utils.py b/pytorch3d/renderer/mesh/utils.py index f61f4faf..749a746d 100644 --- a/pytorch3d/renderer/mesh/utils.py +++ b/pytorch3d/renderer/mesh/utils.py @@ -20,9 +20,13 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor: if bary.shape[-1] != 3: msg = "Expected barycentric coords to have last dim = 3; got %r" raise ValueError(msg % (bary.shape,)) + ndims = bary.ndim - 1 + mask = bary.eq(-1).all(dim=-1, keepdim=True).expand(*((-1,) * ndims + (3,))) clipped = bary.clamp(min=0.0) + clipped[mask] = 0.0 clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5) clipped = clipped / clipped_sum + clipped[mask] = -1.0 return clipped @@ -49,6 +53,8 @@ def _interpolate_zbuf( verts = meshes.verts_packed() faces = meshes.faces_packed() faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1) - return interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[ + zbuf = interpolate_face_attributes(pix_to_face, barycentric_coords, faces_verts_z)[ ..., 0 ] # (1, H, W, K) + zbuf[pix_to_face == -1] = -1 + return zbuf diff --git a/tests/bm_barycentric_clipping.py b/tests/bm_barycentric_clipping.py new file mode 100644 index 00000000..0941a97c --- /dev/null +++ b/tests/bm_barycentric_clipping.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from itertools import product + +import torch +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.mesh.rasterizer import ( + Fragments, + MeshRasterizer, + RasterizationSettings, +) +from pytorch3d.renderer.mesh.utils import ( + _clip_barycentric_coordinates, + _interpolate_zbuf, +) +from pytorch3d.utils.ico_sphere import ico_sphere + + +def baryclip_cuda( + num_meshes: int = 8, + ico_level: int = 5, + image_size: int = 64, + faces_per_pixel: int = 50, + device="cuda", +): + # Init meshes + sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) + # Init transform + R, T = look_at_view_transform(1.0, 0.0, 0.0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + # Init rasterizer + raster_settings = RasterizationSettings( + image_size=image_size, + blur_radius=1e-4, + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=True, + ) + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + + torch.cuda.synchronize() + + def raster_fn(): + rasterizer(sphere_meshes) + torch.cuda.synchronize() + + return raster_fn + + +def baryclip_pytorch( + num_meshes: int = 8, + ico_level: int = 5, + image_size: int = 64, + faces_per_pixel: int = 50, + device="cuda", +): + # Init meshes + sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) + # Init transform + R, T = look_at_view_transform(1.0, 0.0, 0.0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + # Init rasterizer + raster_settings = RasterizationSettings( + image_size=image_size, + blur_radius=1e-4, + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=False, + ) + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + + torch.cuda.synchronize() + + def raster_fn(): + fragments = rasterizer(sphere_meshes) + + # Clip bary and reinterpolate + clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords) + clipped_zbuf = _interpolate_zbuf( + fragments.pix_to_face, clipped_bary_coords, sphere_meshes + ) + fragments = Fragments( + bary_coords=clipped_bary_coords, + zbuf=clipped_zbuf, + dists=fragments.dists, + pix_to_face=fragments.pix_to_face, + ) + torch.cuda.synchronize() + + return raster_fn + + +def bm_barycentric_clip() -> None: + if torch.cuda.is_available(): + kwargs_list = [] + num_meshes = [1, 8] + ico_level = [0, 4] + image_size = [64, 128, 256] + faces_per_pixel = [10, 75, 100] + test_cases = product(num_meshes, ico_level, image_size, faces_per_pixel) + for case in test_cases: + n, ic, im, nf = case + kwargs_list.append( + { + "num_meshes": n, + "ico_level": ic, + "image_size": im, + "faces_per_pixel": nf, + } + ) + + benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1) + benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1) diff --git a/tests/data/test_blurry_textured_rendering.png b/tests/data/test_blurry_textured_rendering.png index ab3e8d9958175f5ffc7bb8afdb90895628c550d8..30a870adabd4c3cf28e90fe3a27488fbd001824f 100644 GIT binary patch literal 43705 zcmeEt^^NW(^lASF`LQi62DV5D?+ivmh2h;#{z4naDlVRVlX z&pzMhe|Ua+*bnz!9G+= zMOfXCE*|xDJZ2*2ChJe73{$m~9CndVkym;CUGcEnoPEL?Jjl{-F7fD#4yafst!&A^ z9+mArvgBWx1=)PRY2;sU3cGAYo2p(8EjJ%Fm(JRHZ#BD*yf9{G)1p+=enon^Ap=s> zrra0|eF`XQr&1_V!Hc*CqKXKim=6`SaSOSk-HN_qVQv=JehuP^F3RO91O9uGQV)nO z;^K_{|BwAYSINO33@e%()!i#25tIbn?JeIeWUP6Z0?Q{q5?05NL7Q89RNVZJR20k( z+_hfOabJ@CSq&M=mFsISf#jTb)_-x31#GF)r}I_rVsteM1AH-N4w$`)YhNv&LErRSgNxYE>V;`;!o)HTioUkv%T5n%oH={U}DJKE8oBsCV_fXV|7DiB) z@Bb1A@&vV&_<|!9wEx-;LhiHij_K(m!DAG-*U-t1`hdO9cz@W)6aQK%y$=}(3XgQz zUZykKqFqhf(Sf2m{`(T>_cPl9(l=hG(pr@d%cuakNG3{J`73je|2=IItw)Ou-6v#0 z{=zq9HDlI3Ciq68B}0n|0A`%_WBq`%7j|Uye-hl=#pH*VRn?}9aoH}ug#pY%5-js~ z2nFv~YyFnZ?M)pQNpKq)m6fOfmPyb6KKv=L0V`DrV&&B~7zmk?Z+my8zx3$_jl?D! zL(`F)mqM^M$6k!wNTHsu`0^!o>k@vEfTGMX^0u@GAgiyQ+RnV=3naG?Uh7nZCdGa< zfj$^13o!G9KEAcS7-yv!>#B}_VHo9(bE2SI2&t5S`a&?4ZAl0~_O_MP5&$C{rM}h@ z{4;$h?iZw7ghP!S=GGnUeH3JWVx~db*N>aFyUK3Rl}L#6Z<|oxJ5A4LsV^P@>M}sb z_KxQL2h0-C;g}`dt{Ouw^z!5jxo%(XdB6FW#=u?q?d{P0z-vA}in4(1V}0K<5Xj72 zJ@a!L&fhd5LxW*qH)eW0eG`4Ev9CgTTxNgQu~}kMb}?Tim($68NYg3=m}?~6{c%!) zssw?F**H>~P*&s3cw}S`-!NNN;6H9=1dEyf0|{Ox;EjQtx(zYt+_+aBu?2 zPVIAM6?@PuG9lt@$xdWu1j`a_anUD%h3&=z=--?&7uq84!aWdd*C~CYZjcAw)Y-?R zF-DJ5Wdj6**k!`l1FSM<4Tsp!(!VADF(N4XXmC}Z6%VS~-WBqwt;*Q9k<<28Sov6d zpW)EdUXt9!(lM1PxSCP^j7s2$vQH97A+jFi$!yv3?EC4J{En0Kfpe6~7uFrJ)v)W% zc24E7|NOD_g-*tLajeQk0_0>vSeiAPM=gB#NpxOS(bcajxkxUs4HgYi7 zQVsO`?-AEsTpvv+5C*KrW(92L!Myc9aVcxjB%Xwz6rShiy-@Ob`7Vq?J33H-`%;ly zxC6(#D8< zmpslepJ~4)(ti#jl50xUK6XlO{s<%{!ojnTal*_{@Um3 z$hr%cNczd1A~@@i>yAb%PfFJY+aqD=JJP+r(Itndd2VJh>j(ANThq7DrM6k@;)_o& zQA&{|;&Cm1&W9gr5iFr^^Ol`U_lJMB9yiY)?-;8VKX`EoCrZ+Q#_VfO4n#2H7eR&9 zS;0GOP~YHfLa6WYJ60~DHtJ%?008g%_if?cu!7w69c7%*Ey9V<;MxqfULJ=WGuUb> z!F7M%8sSj;-2j|td;p?zWy!Pd`)0805 zSbw@+fqm96RN#R9mmy}hyXGlS$0NDdGAT0l;Lxh+G#U&J*-#A0-H24myf6MIOMRj@ z8u}JDeM6Mvk1`wq&mCr$EOa$3P59Uln1$ z{>ipIzCgI*;|No8Q9Y$Sz2?86nz(S9Gd*Mqa&2lZSz*N~gtGPx{&SDyuE=eD>H$0% zP92dN2*hDg2ss*iBoTcg@Sm%ar>Gw}dMHPu%y3_wBDHw10Re>DVD9ASX^r~lw@s2k z@AT>Ct3Us9<}Ko0x|EBm$AZ`Wk^*4?DRUIT3nVPvJC+*^0Y|5dn^ZU5LHO=ajwb zLGoci+#MV8bi$8}#Y8sx_jj$?rfPgVvDwY-ie$7|P zvMWrPcAl4*&EJZ|`a>K)wA;J@hj1Rc17gU0hRQzA&Q+!Kb9hmU!T&6gG^y+Rw^%Ig z;Jkt}7CB}fbX;bFiRNF617|#YZ?w;hV-SsPhx$BS5kW6r^ZF+Y4XH8}Lo7{k;a|_l zEzBJuo_6{bG8J{t!>{-Dh$FGWE!qpl3hNM_o#z&Gdkqh|mJBmvKJ=t@M%zg#%A@?ff#vAs#|5n5`f|$ zUOy5^+b4wsw#+DrHOy=h^&u>+sl$T+n+_StTpVLUB&+8>96W zbX~$rusb1^?)D+Q)yH7~9BOeKK(!(L38M>0;p047HiRi`5n|m$Wn|lwpNFQ=ueKCE zeHi1GLVm;ZZ@B@kVEZTQkVBgF6Y$tIhk)pE!}LkgX3`81y!)@pjZ95gscj z_!{bg+NGz&3U@i(6ZCZ!*|K2h(>&Na*|EK2y4vZh{C(j_wbAomTKWX0{UZ$#U;ZT6 zxWBtM>3n4y>nWHs(NtwtIqe{9=|UZSL0E?ejw*N_vQPo+o~~Y67Ej*9cAsLu0N%F5 zDB`~}`Rx?I_iT-cq=GOpkvDs$DTMeD8(~Iu+mHEugoSex6|Uw|`-SFhJ=1+CbKZsb zM%zYcw46xjb|}UV6#w&sr0)KXbL1ZhE5<_c<_36Xb5zeLw(pVJ4t470 z10rUvHJxlda+P!;ZX~oruJ3wI$9drj_p?KG?SCs?$CNh_~a;X!ZE}M!)xzlD0C<6OouzNzR1%DE<3H-a~ zvp%r0dHo37dNCNZI3I+rewKT5OY%a|SJX~aYXkSR`cB9U8{lzZAywC0%B{XCxo<=I zorujf#?FCwX(tz7L4XGk01^S&E(@I6)M{ zb{K#{a1lpiV;e`IJ*~{W*aKwDO4e_;DPAU3tqx+YdipZp=(8o4I zIowk@>`zYW&3xk!Yng>1!-sn}bPBOSI}O?QaP-jH+8a&oMN zNu@KnJM3PGUS`=jA%4vF^a{B6dg213`O@eCGjw&p!TqIexn=lc-9|~2N`Xs)3$$i{DqskK}XB}9WQwPBL<4C`OcV4OiBbJ#2NjaG zAyKc=Gh+k+!eD}#;Z+zd!h#lYD-gRzD`%cEEndHiK0U2H_ggfytrVQ1>!viWvI3Es z0K#`c0KoiPEWQ~dTCeLn>Gj5G4=-lc5J3dpU?~}uPFrpTfT}uM{PJ-j|UR1rWL)4u$b8|ckL~EK&eF4@ol2-wtSF?Eq&wxWqQq- z@KB&DL0`Amfbc^!a^I=d2phWW`FjI_l+zSaSwBNxU%$K!nOi$$30FM&i|!|t6QX@vc#Y@EW~O3vY6R9u4{rrt z?QI@3#rODaH6^=6X{iO$j|Qog5OZHNGk@Io^7i-lclWPf+()<$;a$znZcgf(qQ>Z* z9Yv6O%bbj0p6HwjK6#6D1Iqg$VGZx*#ai3x=EeB)F(-)6;dFT#(h}!)CT-$&Q`I5u zR)t~;V%>cIf!b4+4a%ONm^I;=GTPpxL%n5jhe}6t?Ymi~Q?q9h9V~f7GJ)5rX1Xt2 zYV7E?xsw&|dGz%&>9ha+t56IZ@vOm?e-$-qm@uAvB(rznzm(xl0tag>d$M)V6@nh=s<%boB<#q{awAIUr% zvq@qD*h+6=CxnUDoa)=Hq$z((oFR;GejNw;X@Q-sp(fMkuxW4DnLE4gq9M4_KgRC~-!f*cZaus{f1k!K zL*8k8*n-XNE%L!JR&ejV8#iTVvj85t>JVLpln{P?F1%G>V@;4y#g}77yy2mR>taN? z6Xe_9I;8g8hHjKLytyt9sZP#iX!Ik8UL$RveEIp55&A+hK@|A{sJITCtbP|Dy8kSY zmrz01>x~C==eIiR6VsX-#$H!iHtXM&l9JfHp_X=BeR%x!y8% zhfQlBI|cN|i3 z_JQx(34)nB(F>z^wr*ZTW3*OrKe$hmVmzmjdcrnnCI`;;{-8-L?SX%PVU1$LhOu}> z{b=sFS&Wi3F*fk__HMp)%9g_a_|hzr!8OvIDO_Wj(9d`83ZLcI1v^{%A&MP4+&(+P zA~SkD0ZF>$SO^YSDRge!rQD&O6N*0UG48G>A)7{0;TQRG2GrAsv$fL=Nl;-BZDs8g z8)8Flr~37rT?ad(0e5Y%_(h`0HOPYzoQTa|8hqRUzDkCUqE+ej5jZ9M;a zG8XLH)Gk%%|M!#YjB`fYqGoXhwmt2Yan1Ep;tr>9^ijfBH(CrozI&fOEa4SR<)c1{#!S>Zuoh|-r9hM$E zC99T8^eqz*S6A~);l~|)R6@_B?o;o<;6@^2DdGDKPUb~qd0HJcfVs^gvxFOwM~1Fi zgZ?1RM`IaMr0DdVYyziIw(!WY;?-&&tEg_1tu&fnW^rW8#FpqZHnyS{p8K>|g}d++ zd8&><&AZ+JU)TJ+j($Qs2MuZ{cYcgci!kYWa~%F^smwiU+6+czIU~fxNHesb1%CV9 zXK2ZbugMynO%E~H%5~?Sma&16cRJN4I2FB1Iz8T6%0IJ3t)%*)!&X`^--fTqAF)C& zDgZ3=qx?A7lHszqf`EohSclWkDr80NljaA73hbNFfQ@Kb(?>v|y8U`r^Y(13JjWD9 zU-vuxec@Gjbjgn?;dWxUi?8d_%^VlMU>RWQb^@^qO{;#KzP7nX7JCu2w&EP7k}P&w zc~o0~ws%?rliE6pM@f`CZM_*^O-tbJj(6q_5^qIl<^u+}9wh*@zQJ}jRj zhqGVeZ1OO0-yyf^qSe-O9&xd(+CnZTV$5VujnF-+wy%F1@kj>;h^jX+mtmXw{Ulkj zR=UHMfk0C3`(_XMVWf&8>q>-m9-FtR_^VKy$VmN+jBh38pdJ5Fj;mF#y|c@u1Wd}t zo(2vtv27QLyE>$y&Q1kb;Z?|~^2xG8Gv6~VoA-F6P!W#SztZ{hLeZrK2q)RXnGYV> z9ItpR=Yt@zn)x!Y>pj2bdD1ShTtjQ~sfMyAZkID+J|f5L#q2AT&_*oquOK%tzsK6!#Cy+gyqg1m`!#V?qs@0Txt|>~ zBv7?w6zLa{hFtbOy8Lrf5MjYk%8m0J(NIPIYrX)rn!%Omen9NnYT8^SZ*wnV z0^WUfFWw1ya_s<1RNx0e;RgOvv1|V&%hAI17s$*`MZqSA6mY?v7$7sKM8{~57 zpOT}w%gNov#YmyxlnWIlR(@U{>#A41|IK>8n8uoLxX+c&s@5)W0*WbTY|r&KC6Qat>80?q^UWWl6?v=O%gdmk<8mz}+voA=X+oLfolh`kpL zE8pbQ)|Qv^CVO&&N$DxKbLHTc$7Eqi*c8OBS8H{k_q9*PA;Nr3nR6n5Kd= z?4nfWa%8Pnw7yzPPYN4lX!f4$+^m)N3L5EyMFo#Qqao zed9LHVi(pn3u3I`miiyZw&x$%8CiA-!+bJ|X~g;aR2a>5ExDGzpZmo{ii1==I|ieh z;HH6eX0qwe46oxo^weuC$Z6Q-gIdO(IpRloB@TgrySRAe;|jSl+*fqvyGar z<7Ib#w4l8}TeQz(pv%6OsQQH1tNcjM%(`F5=C;FSpS3vAC7&*IKMP5hi@U$KwR(E8 z)Nkp6eqQQCP8?kVVc+(}eP<`9YG)_s!-HSViz%g}`7-_@@e}LP+CDdDkCY8YM@Orw zm{>y!TIy!1r(o?^Q0i;#FBBSqHLm-wWBq*5@5aJ_L|eP5B*1D6{g+=N?O{oJ)cHTzHfgxLdVk7+&o{QKRHW%%w=TM%K~7LfBD198s^Gy z_^Ww-e{p|tanTEXHd|}KApT^hvVM7%^t0gR?(Tqi?+Z|Q)uDuKh3A3~J^;klDG0}= zfPk%JB@L_hKx{BL+kKEvKwM2c`kj?J3uy%|cnUcg2sx6fnHjFOpL<5BI!Dnij{t{( zWB(BVh!dOdQ?Dq}N_+_jW-b>6V)^g%by2gy7?aDLgqZl~olNk$r)ajH5VPCL9>>z$ z3SNp9b~4}h7ra$5?ktjuF5_mdY-L`Zl&Z_Cdrc6G|J_Atv)1-MOI=(YQvloOyFEvD zuRt$fU;8{ebTWM^^C?OOMp`^`(W?l+kpz=8s?TE?Q1#BbxWU>gV07Klc8Y zv%O&c{>}4;m+00@i~4zd=B76Pvu&zr&k5n;!g^VSSIirpp$?`cglY>xfu}#F_GL??a`r^VFJ{)1`lJ96Ab9kUAgc*5L`_7|%Vj_9lkQ60q1dMB>6h z6arAUuVgHVwO+&&{00>-T3!nO_(WgKF|DFv5THm4Q@5TA6;z?DIaUUP#{lvlfvcnL zGlstXm>e};V_F%Vi`$27kRAB$_o;5S!t47N4k4>Dvia6Nzh+bal5Hs@Z(%*3P_Wtf zIYBS)g5eCJCyQ76gDhK5rNObuG!Ypt3Jv`y~EkS;8UqScbTZWt>Ye?C5N# zRyl#mRat~@>4^hqZZp?qigm>SHtptiEB%1W8lW0v;vM8@L;67E!%8iyrM`6D03&` z%hzXaEHohQ_gGBNQd$+?QbU*gFVLm6FPrh7YhO=a)~Q`^9y^$W`&sZ5tcWl5!uJ$) zOm&0_ZpLfY;otD8MNiBe+2x|S-$=Qy=4b|O3MsgF{Zj44b}&(O;T>Tj+o<3vNY44( zKlG&|koi!2I;BVci86EbHY@YW!FuGLqdX4eI4kIHfT-TDYJ>OvGU^(%ZAPpkKFmQHY5Y7GEG{xf zS$Ib3W7?&6zsCQ&G}Zx9_&hUe_?~0=EnYMkjl$bmbz!+d!QBN|fgC|#HiyH9i^NP} z9%3@N!bPvkt79u>Df-_sSKu43=8e(}nyu=#lkGI+j0_qYTiYBA&Xv&5K3!K^@!8)v zsrWI{IBdfx=Y!1a?erH6)JFgR;L<_U{F5j5bBDnszsZb66%Q1JoW`9kZ(iYWSC1+0^K(OSG ztT>Z*HJ<3lw+gwui!r-d)el&IEESwgaDkSln~l#0iN}Q$kx1{1T zd3=0*AKlgh2BS*lXG>PNHUeZAQfJ-?v{HM&UWhSk#;eAB#0(ozbJ?Hrl4-<3ad zjgEV}|9v0^SO1T6!L8Vicq)5Aa-Z&El?83eG-uOYk8^NC7^PUB3_l;s*C2UfC_FI; zvnSRy1)HjQ;2x3dgw){cxayo{%!vy)B(X?75gxy1zUeNNCn#@AV8iL43w4XxJK!!z z9e?>HD*DTi*uR)wXWZl|(dHJxhqhb6#!F~#OiITH)rASO2@6_G^Y^;Sxcb_j&e_>{ zq9pjKE8kzJDJpU{eWY>neH!O;g!|)9g^?xpep`9o`O^MgFxYM5DC}VNymlsRc8)tP zLNkzhTC1N+8B{n;* zL5O}Dz4gr+!&I;;Crp%wgnGqFJv0S9qjL0qtdGOPM@1?Mq)L~jvt&KL7)0L8{T+Rx zw8P(b@!`GEqYzSy6AM~Srtg?iIug;ZInn*$xbZu3_FM=pw0W^<)4GdU(x^zzH{=`G zR>kQOX_xsv9hdbm`zTRgIs5t*XEfjSNaJ8k%wWu^D+7$6xY0hj{J*^}$D;kP$tf4H zg3V@NRq>)r%pms1M{cU;uVubao+?2}PNrZqmUaewgg&7)j}QieTZigP?l~XNIN#{5 zKY97U{Ph;_P3#jb9s|v0PuNWxN&UAsfg~6>_9QHF%>2p3=W=fmfBopo*q^HQIU@Mu_s7r8Vnp6dRK0}>aB`9OJy`BoSKCIMCM2%W#Oy5fAGSeMIB<#17-2qZNY4WJETBl z$q+d}434y`Uq(q~h#e|!gT6I&*BdivXr$WS@yN;+-@lA}3eTXxX=!b3ZV>y(c26e$ z$n*;qpGwdLt@$33*?&ljPcM!n9QxC!l2@}#P5@zGT%SlW66=f8kV;LYoypl+8Wqc0 z7eZTNKevC`yqG`X3KeE%-WW^rd*@cG!)p z^>c>j#OFO3fq1+)(#E>dOD2#MdbpCcrh-+TjzubdZRM?wX?_gINZ;>dpSSNmpf`m? z1;){5B;TQ-)Ja!sREde|-eP&M#`l&0R`f4&PNG6TCb zG|7Om=$dyj*XJiAchKK6c~oR%r@P9&Sy~w6q*I&`Q}|dxw>JL16wm&fSX!ij(OoV<@}?@N zV;PNxyve40%mQdy_iuh4dX`;~adcB(3g!W@dNn@3ng;*Bz|hNL{>4?4p_&CNNjeuMG!veAi?0G}N(D04LuotgAkN zyn22zp;j)DA!Qh+CPi1_$x8f`1P2R;KJpf3Z9VbUN>VWS_X*v% zTg<-tT$pIvI+U})&Bvvl;@R8eUWi|tJ07R1K2#9D_U+1X3#M<;pRgtzO98@+&ahYh zet|fDE?-ies&B+5y|={m#s)4ZHl<*lVc2#3pWVbdCtt=$IZsh4Bh~$DYm$&frvmv~35Kf;d3}JN z^ZC%=-lhM+97j^t2DfVK`$s9iqkklGPj%}s{RmkAUPYFaB;lwn6AkFed(vs|v(spu z6{t_Bz(1hH&vegt1FvNR(JrYR#|IyouXQ{-80<_!DQ8cA+dFizw} z-k$4fG3wyVvSb71m9BhsEH&vA(G|nEv>N$TQ6D3rH6$$eU!Nx?W4Vqz?21G=ut&R( z0VkQiZt7@LuV21(GbC6)2V$+50I(#LJ@xirG$yQ5c$qG{zMEFoXmN7!ErH!w!@|oPbs6?w%syIo^rTv!;4~`C{Ju_O;}?}1z=#9&5+EWLD62ca8B-5TN0F| zY%Q=0vi-32_7FqadH$Mp6v5!4ttu~17Na+~E>f9luXNmHzA+SXJ;Y(zZVsfy5B#@# za+1dnJHGi-eq&?m^5w5~60YNzaQtm*K{h1_0BDSl(DRG0ZXx_nFbK|Qe3!ZP5x7U@pAILkm9Z#5Dg$QzFoS)d^Fh^o7+oVu~6wE87jjN}Tle2#AffyrHH@9u| z3#46{-5|>H^LKJB8>BU2za3^RBu66b)o?9Y9RPJ}lA5t#m}F+zklSQOrp=B^4D0z_ z%ff#}5_i9tqYpXo+o|5HTAwzV2Hx;}&~{6he!VfgmGsjafX&QBJ7O{@y*{#_<$!NK-eW7Xl!=kh^OQ~yf$_^$Dk0=JFMb!pC}+7tJw3)ubk?#Z9e<>&7)maKS# zxlW4PR#x6Xo{z|}iL|DEfDHaM@yi#9ws65AE^);P#<#j!#o|#)NljMe{8!3-C`4_y zYr1^?Jak2ukc{fUXKPQ338{%}4_e~ze^9L{^oooLIi^#-uNe@in!rQ?#& z%^$}~GMCKoHFp0T(`zxD_ck%WXCi5tKlnAjNZE3kFGNi?Cegf$#?%b66&Uz)?sq#d zsW*MOfUfk2AK|X5n}B`NAi*m7(Ko(~sv=?qcjyPMcQx{O=8{w7++G>@q+Jtym_m}a zqc%U6m;YCJod3YDiwoq^=FD?iOK#^(1F%6OxK#!5QH3F#)BJ&fOgKnHqgjWAE zIXHMM)4VSXjwpOWu`I&%))iWzGQLuSoR>VmZ>hknHj9igxb}}K+F~2c^uS=?aRfu? zlSqEIIK82PQZTYo4hX??pO^gqC~g;~zybBigN84V^yeWS`H=*8Gha2mWN58QMle3L zayvX+NqY0@teyTpPp!Vjh-z9rIwFE`#p}r)-wXKYtE~*)QFz5tL1L4yt@5jcXrmM&Ir@##c-wHIqE*~{h8iePkxe6 ze2_g6^jnZ9S6b<9HUpE1NgJ$oYmrti`u4HFJqK+!)DFuqbfKq?9>43>dM&c^1Ix0ad@UC)h?7L$+eMqHe8Yf35bCoV6wxp}_u zO&~x{M#GMM0HL6#Tco1*AeR-_+!70jqxXNj(junk-b#vLlH8xZ%B?2oChRaAOoC(V zixIvBg+Oj?*}|rc9dhfApv+zavKt3-OXz-O&}jOd6NVzV#1tZo4TJZ>aVeJ*-QGlT z{{+r#Y6JHt^{e^#(U?~MAa;hh%&_=N`!BLGj;LyO%aS(P8Z{)CfpT{LQemayKckK$zt5@;VZkvs}7)KNwR5n7)s6a&qG5;{#*x zP}Sk7L_cU=GfE@#{16uRk+plUcu=^@G9OC1bHQqZ4`t1n+yrkbL~3m~1b;)s&k` z2|F;jlk}n(fTg&zu@*ZQbNQm=cWeo8py*Ebv;1ZyoPkhSQS`H?AYwLFTy`s7zr#xI z7oXMyPCBcroeypYQVS-wG2nj7zx`B_kDqTQ8;MTa8@Gvlgm?f15bUjOeuaYb9HQw8 zp*YnA)!u07nB>}FL+c6tS@cZ_8KQj;`$q!-OJ6UMf^htVSbUe*(&A$4>6NNTyj4%N zpve~B$S4GZ<1uA)>Q=#Ta(ji~LRa+k0HDUi9`$OVyPWy(|B}oI*Hs;B>Si1};`|RH zn5K`(7BO8$#qI1fZi&eByESZ1f5^JxOn`(O>kvt*P`zSlw1~@}?b@k*;P~#s=jW&8dJlum42Z*}@>x`%#DFo|; zPYb-7qGRR~wTx-IUQE>44ocfu`(leSaOKQ6mM;Z2AjsO9#d6SHRZn%*%`f`RP0P#0~MPWrr%7; zIn*vKEG?3Edez6c+J^|AGU^m&gf0N@#(MKhMlweXeI+9-FdV#dRl&Wt21+~vix(FM@w!geZOawYQ{t4hsqXXmk8Z2@oI7tu{XiRREL8{} zYb9Ba;jeV!Zy|hvDU#n3A_d-@%8DTZZrJZU_>Y=TtB9 z$W$3qUV~_wvubErzSY^;{A6V_$t}Qifj}KR4wpWXn|hi~htY70BJdzjQ29pUQ@^!z2}su2Gm{djI}x zSWj+2a%xkr5!>*$N3ChaF>>Fw|6!H^s)0bBxAd|{fdh;~6ZUYzx87A`L! zOo&bP3}EwK8FfMj97zSRm`Jh}w26HezTI%+sb4z+bBexxW5mRSfHAPtgUKIhpenFI zqK|rB*Rj$C=>P%*qN3hNy-O36dQ&1Qn7~}B^E3(I~qyk$(f3M5}n=&mO@Ai#D===S|FYTDJ7>^}{H*r!s z>u(2R76Ddv3m)+En;naa>L_Wxls_JhSB9||*}{&_ zB=7oXe1H4b??Ooy3YeGT|SFMpOzWim4@=GzG64Lpr<9|=B5E0X&f%z;~ri0 z>oP50T?|*fHFnc=rqD z5?iNy)gFPs(R>e9?^`naio~y@)_W}Jv032kw_EKO#a2mVA~Mj=50Ik-wXL%p;gz2> z_@2rH+O0OS$-y%{fQZb9nBeBK{!vf*f@{&Em!tV6?j0+YIe9HK_T3NZtoTHcPz=wq z+jKho6hEv$3a0k-R+AVrSA&m)mzS6GYA{2Z|Cu_=gt@8>Q}3F^it>q0Wm1CpzUAIt zxW>by8^gaH@w{?WRZU``k1R5)0OR6FgXe9&>R|Vv9w)I{mp$`{x^g_8ANuM{ zMD^r{i;6F7_GgBxnRvF4BICFmYZV~S?Fo>{_N_fj!8^WfeEom{|AMuJDn<4sHpd;^ z^+=N#XK#y_XOiYlUa#is>jD9cm`oFoFoCbSNbJe=c#3TWFhit}<{Is6>ICmI z^2bavSpMGMJ=nL5&&V!KefaKs_L2&>P0r(bJzhhpXXRl{1z`<1JN#u!c^jo8at-6! z69+#c_rh;e6LilXIhObgU*Vy brznX*R{aawQ!;{ntn_Q&de@+Fg*Xds+Izs|O0F1dQ^nlXAOV~6gXyxQ2 zgg}nK%m*Dcs=*S%KtkYbnwMjb56?Lzvf`z47nd^i8037W^%tH*^+VA@_^yB)^vWLc z(=M?^Cg{!klu*b#LlM z$cR#(J@{&CyJKxM=f<%tIDaUYpAk1<430)I*FS90sAXqmgF7sQfg=_f$!5>PN&n^L ze2{X{8NeIX54!RUZCEDtp08B9r%=WW@ltp~3;WT-Hn%^o+$OqsuDO9rn2*B!XhF6rG6raSFPvlciin(}5V!TGZ1G{O$tb}jFG^mn;ZEeLb2 zH`x;a_9CL-6@r>h^|P}#He2paLsaoOdS-f_#hpLY(@#szJ}HF8{rcMV%ah|P0hM{H zap3E9)Ssp6Ro%m$haN2E3y!a8j`+lb!_o>iB-cHmb^IzQ$PXRfF9a_A3~KcF0^FGl zvxk?{f8{0S*2b(!=arWO;^s3Byo{H^|0e?$a0FYvRi!4Jw^+|EoL__j7)4Ah>Za>u z*x(cm`@P`VE89C8hsjJw7Sja_=hk1Rr7W>i`V}V6zjQNmIvLyOh$*^wNQJ94c{akJ z!nQDf?NaDxlzaD$$sm^NtIS3x0gBlRGZWPqOywM&Q*JviO%~9DtUeL+Zf3k(xx&nn z7PMc@7u~$Y9Hb#fwH@BE13>3~VBfyw$PFXW)XN#{6~D=tZl=`JH}3?*E(>YmlD-Mv z2UE)t7@lDk`#y?ApQUzu`7+P#!FXIPcw*St`kP5L*(Qu{osGr%uDci*egm{@q;6e= zfwAJc;t8DgA33|8*yvhc|4H)Z zBjSacXO}sd)$9azXTB~5nX6FcrWad)#~nidD7|_Ajbi&a=E#g8Tv~48d|vSy8!%rM zPrn-}`q{fS--FU!_t6F_y~8$t>tDq3P5fkOOoX;H2xy#yJ~2^FzTN2L?KFD|?`b2S z;kB$GQN1f@`1qFqtA?*1;{ADxiD^Q3GxS)ZBfq_<$jnfI3htJ8h+1sl$HKEW@g5PU zpO;inzU;ybk?Bk?&DAbnw%^G!FPLZtzWRhY47;ZFUKb8G ze|YkRoM~%N&sgQhG3m?d)pIi3jv71Bmlc~_I9l}-{^(3kwLEn-HKgC?;1^Z z=#qTq0j#q$Z*$tu$MxFj%SX~KoO88jhx%$%x@{1+v<3gG5Ci#1|ZViSaalj%*6-Vq`U))RQ@i6~`4fRl~V3YgFlH0H2A! zj~K8i*vjtxHXO^%o0uICK1S%Om&#$uIdq`klhanKiIfL6Gb>b}2FcmA14%GvV6B5S zZ>#SADUx3i&~D+Tr{Dk4+0UgKK{YYk&Hf*j&N3{jt_#C6v>-5aD%~yJLn$dpH%KE$ zr^E~;Lw7dgA_#)gDcvyh9pCHw1;0-0z2@w-*Lt3NOT<9c1G!5c7L)hM zcf;{q@A4!J2iHq5Ho*N_@AJ&3+S}emJ-ZA$Q2Jw#HBz{dih+&q1KL4rhw>d~x#ni^ zmy=_aZ74TqLgFH(Zi6Gdi>cvmtxC%?eC^uzhp0p1?%)mmZxrx`NFhiE5}Jvk!BAA( z(xflvkIM0ISRZ_S{h}#INXvbERftWfp<6)-ii3b}bb!eoQZwQ~9&VDfSIPyJcxdY8 z>rB4H6R4x-l#b+#0<9J@b%cj^mX(4caXG65i}xRlmd`dXs~N?dQeE@>rj8eOIFsus z@clF3;M=I>{l!CU(c~XNBhI*oV+P!JUFE*a48k07h_V^22p6bV=j`K=bQ7H5pz1D8 z*2w1kT7PJVhnd2}-2cTuNfmF98dDkPt3+5Qv!p(u&h>IpQJ^uC`rlVX3Z^SM=6KL4 zY-O%168A)*K`0)=X}FB&Tl@3pB6FNS!ysk-c#N!@5EsfA)SjkrlXG^ec72C=I$&w#g*N>b zY!wr`Nb-ZJpMJ&0PLBHMbMb6UVng%k&SBSOs5+$&2oy1Bs>Fmv=K5)l_Q~6LHVr(; zHwZSPUV3uA*f&eHnM$9SH9Gn!u+3_4CaN&{)uug>6(u&KqR1RA9etE_0UL7z&M+yC z(y_b<-XQ!TZdMt`xSGNUI5PW5M7$6amSPi*B#r(Oro{%&7%+>KSIaBMd11xg7gN_y zyWMI>oaNXPZbhA3=^+nrqQvD$gzZx$HEW&njZd^^M6O4)(=H1UkSSVQphsc4RZ~Cs z^hc6+orc5@lOZ?8nHG9($4^JwFULv4cCi=@otJN&Pi90C*8oi}P>~X`mlf2S5z#WK zKgRY!n(WosQe`&Jtv6L{)5a=^Yq`Nwv@KXO2bW+BA)4tL;46`?k3PLKi_*;V2`D|- zm3%1jwzzW~uOUD5AKB> zn+l*d$mPyEP|b7m*pNcXLT-viSp2Z%H#nufb8{NP>bvZ z2xL1h(q6a6xiuZx%i^0L&Lp9p+O*E}X!uuflD*DJKZ1E8RX3R><2RtCc<)SoF(X0P zFW8`-_cRgoP6(;Z6p$nLb?bQ;9wrGXHi*TB%6O;fs>=hz$(1%54~I*+EB)l1ho`+Z z=%^BQy7EouQz{UJTwccGyk_>OXso!4PmhvM^8K)wJQEGrc3B6%hJ#yU~AJ(Q?lTP#(Xs-!Y zSd&O#u|i4!#mD+B0uYI97j+pns3TO}kmm)3)$=rpue5M=>TEAERz>A_3jDxYCD6R2 z^z%aPETu2i$iLAY(P%kN0a+knp-kH!a)lc{KuMkssfy}U!QvLfV{fGdK=bR4L&Lzt z3hw(!x8%O7v6ClIyr3g-!Cf$L>d;k0Vz`OEBF_W=O2zr2I@Sxyojc|43?hWsi|5rx-mO~B944L^CdDE zm5>a%w~^^HPxml0hA;0f+%e%3RRWJK$4!u+9Hf7MTNoUc`>y-=p8w$P9xe^`_TS^p>S4&LI8AXFOJ1UFHo`7-r(GZ*7A9p_o zM4vhHst2^eK<%@WAhzjy=g~ha-Iq%UlNOv*4LFW#oNZD&d~>H*BIw=U?c~?~M_(HE zVAOA&NxM-EI=_=}@gc$4cj@AB&kNF@M-2-j{`$xr%EK&Rc`l9`b6SdqUaX?zelH3qkF z?W=!2d{M@_J0_jXc3F6(3KHvLwmE%>Wr@Uv%+N9#eUG}MahwRnQh5n1q1c)H1#7Gt z=71-z9vReS{}FMPHrz6pZuw*o%S98hlLsDVG%2-i%ZU`q)buKUZ#mhB7Tk?9W=ob< zDfW@%Qz}<8C5@b@I6bj8e}gg5pNNA853k1|EHc3hBesQ%P)m~3AY7TIv;rX%cHgg< z2TT}w-T^QPe#{t1b))hWN1||sVF+LyB-7bilBhiRsQ=F^x|L_323T}PoM z!@;hW?$rcW)1Y4Cu>muunm-pRbEeAa1}0BVIk}a)i9|qO501Zz=G6T7nk@>iq1NLj zeIGLj;9_!YVThYo1)}*xg`-ng_$*FBt1A$t1vyt30tPbrmHQCawb;W-cCz9)hWRzLi z$rz<-(%e%ma>P7lRrNdk)Z{nkSGRW?Dtdu89Fm~Om;%as;+cAZ)?(B2Vrdh9V=r{jOu_06a9G z^R}znT3HUvvMsnu7rq~(8V6!!6zSzsa{c}sXNnQz`8SyNP;q@I*-P_XWYWP*jd#nT z(LeKyl$lmsx(l(KLSWT5oWi8JC^gTXPf95oqY2v>a0MS^tzq@V0(+$Ocxdk6(ocj7 z^~tzxgj0<#(hunyvf>Mv8ZCa*WBf)!2IrUXnW?}!G>VxOyk(|Bys>rT7GotgqxHm6 zr5FEE(xw}BZUZg&#OZn3V@3oiNWjCAaRW>f-A+&FF9}_Rf@~8MpvZ*vT6SR*cfi+e z{Z;kTwxMU;-skVHzF!vj8h%Din1k(hn+lgCPK{5uHqOv~(UK$_W@jNQ?sIhbV3NTv znlmDQzJ>@u)({}jr@_z2`uW4#(C?S`RfnIK+t)SJ)4qA>v)GdifYV3%hGso7+9Eg~ zs2Keq{i-Ei{dd!ixC)y*j2navg5JQ&q`S(H!}GS3_wCr)kHG?_r&LK05-(xR2&Q31 zN6BIuEkS2R5cc@-OP2!Qy1kq3y<22Ea)8LgQ!3;_mJi#%d|#k2#lhqo^=(_RIJZr@ z=r6GCH`7Jmxaqk(qP}ycl6Oyj$&mY+ zS7Pv@v6rxE+ntCSft(qGl!ubM`t8M@;d2@Tkb0czjh*^IUP1mw3U?QkG3>tbtn2CZdEZ}E|K`fKV9$KZ{9)G5RGu7`^5(TF!v}1f z@o`$;vYPNzTk-v2!(UP)BqX^jdk2A$VoJbT+e3(l85q8D{B;S~9~&({-hSuKMQ6v+ zeCqm>o@HIxWBC_5D;DGtDw|Gel^mq@``zipOdcO^wfF}4R3HdH_|P5Q(kOCs4Ah97 zUyRA&vpevVPCU>5Om1{$`{5>{;(lOsewvoa!z$8~`l4P`v7y++h1zBQ{ig6556WJhqSXynt*&aDA zIpmfmBlly#TrEY|iVOaV>)ZBQ=Ak{EPsUQSs8s!-@0@U%_3Q($C*-c5zDRAsdy9eJ z+xxxD&S+xx;S+oQH}fRL^_C121#;pf@_zW`WQ#qUz6~Cx zDcFAMtw|FrN_ed9`p|a0Sfg%df@AoQkF)lSp954ra-8l*Zie~t?qj;W;NLZ% zKF^oPlaW5&`EnBFL+HsLe?%J78!{4Fb)t{~luqHKOc;(@-)V(YUk8%VIcaW4<(kSp zP*K+4&~i?5{)Ppj8}P%U4rkJ9*D{_b@Z`$X5zxi*z7CMIA{nifV?jFz%j?j;=za0M zozRlIR#42laWFMuTEhCl&!^W#xIK%jovHRz`CbGd$b_FUr+Hqj1xSUuSO&Swa4qCQ z>4^Ow{DMtsE<4@1kWPE_sLJ({H}tWpu)q^vu0hjjt^^@+JlC_7Y^a@XxUGy0H@?`D zQ~Bm+s$9ACT*+WeaIf`vfq>%+Z05wAFRZp-m9grZ%?uytJaHXG25><$PbA-`m>a<2 zW`|Yn?ps?E*4H9hVCpC1_b|ccWjH+t^7=W~nM*wQkl3Ql>(#&1#Kg2=AU5zu0$NAJ zuMMKK`|*Nyht8i|*y5qecO_-^gr{5s1Dvtlo6|*IONez6LQ-PvSGuN#Vrh5K5FSii0W=a zDW2V;gbEl)9`zYlf4NDRUFzb__Up!i(UwcYPda-Q2_6h|E(Oc!*tA!)`L5nDi7sh- zK9Ao6VXPY}nOCXMZw@Yzlz9CW|AJ8oioMzzt#d`ip_SjVyHrvm2U1~0hX?yQ2HmeH zG6Fu)CXc~SM4PYGkM;lh1-+w%k^nDFn*M}@%dc}X!2uL`CGT*-5Lz7?9lK*DwZxIY z@@uJ~Ix>GaFLgbpp-B2h4i5tdi}fq-Js5AKOlEAUG7KY1CW);j^LbTN?R_aGr~lRa ze4<%9XJ#*3SRzGSQ%NwNai;d`?0e2;DdS;WKJ}-OZ%$#({P#5_(OxJ2&jxitr3Cch z#NgTx1|rT>MSErxtVrpF7l^x`$zjD7DA@iagY|6X zefOw+_l5jlvYzV5z;Rg6+&)xx)T)1Bt7+_RC2eh6;Xg__;WK#BMv|)HhAiXwAzS|5 z{4kvo3k-N+b#2%Ng<^b79Vv^#WS9q6*PBdK!{Dh8O({|)9tH<$=z7V0UuZgbBy-dk zp(4lkTB9La#YOi7%$SqeB<&bej4HCo>?6uuKb(ogMV6*|yFd10y$x-d}`)%ukvkiJ;NoGst z`2qGQhS$@{rAV88Qbgi%<-7`B>2;BSWCflx1j_YqkV}d1l^AUqIl#wylxF0zd|c{K zHXjJf=X4j{xGC-0EgZ(?97`P$Of+zvUqxigWSfsNnx?VQNk=&{oqy0S6?Dij+$ld@n;pil%rD1h z(|=L$m~~Pk9aW#rnEw2suIi5xqN&qbb+YjMJS=a3GulzVH?~tYFaoC0R-D%OLCDHs z5OkD2oPMq#cK&jQw;CbE)OEi8+0>7dBR;TKY*ZuBiGUZ|t&p7n4)L1OBNfpvEm7Mi zkM<@t8z|9jeff_0Rn*t9A!PIiwj2mT|8D6!1NY&aSTN$nuem1oLb1FNHmKB_AZ5=d z`k7?c=~=w@)gWm)3_IJIvDG>0ZRg;oy2Us6F zhct~8o?E3dTjoQjK-XAw@|w58D;YTiPJVismNm=rtoKo5+y=A}>1P|Km7w+%^u(W6 zY&u1$M(89ApzTo^aJw+EXGF!wYJ%)>-K;9BkMvTtiO10UtMzcivB>KAQ#pGL8IB6p zN4=}OqlfHF`fu9O5_|2P^IBrkUg(y;whSLm;#J$ej!^l4dQqgbV1H_piE8rcwfUK~ z?GvkHC!Q`LI1#LkXCnqudxYaI-ys&-o5g9beHk%tjoI{5eON{63;_+A;G$=#US{yiM5ADxlsaC2 zPxkebHlBO=zB?<{tj>YAi{Bl}kPQDIWHDIGdG{n-B;M5bpWTJAcWvt4kU)H)xv!b= zu}k$hTWPg4!fG&WNZiQk(@plK4Pa#%ik{)0)Khp8&)n8nJ{GC-dFb!QY15P52*1!$ z95iC@(#KY5Hb51nU*r(`XfA+y1#h^=DBwNYw*sn?gdzm0V9-?v>?#QpPW)%*K&+h{ zPPCZAd@Lg?AcQ2B)AHgqYwhbP*m4~~SD{8t)?=RVLPo?H8@x}p&vv!>&8r7L$isC7 z9xm*32QyZ+jHB*1A9>EG-UQeg>Zzf*TC0bz2GjILLc$|RWe-y(*LDu(qLX{vRHi84 zx?yo+Xeb}eVrW4-lt3u6| zrT-`izA?0O?Ilt2%V)$*-sYBJI+kD%#S2~h(52sPQ2C8_xsvjTi;hCoElXH6il~J5 ze#RNFwEx6ZbpH1Z#&{8Xuo+6+w3y_$SEw8Q+;uVktIX|*?He}AG%FE$C?|MXq=!=X zjYFYnp$~MI7+**7w>QQ&^YPN7l`&6YFNW7u;~gdV53K1#lFcTtVBYLcMm2O-33!UQ z7i9&))fGJE=-VHc_;F6T$s2OLN2uR`eDkL<6&cFZ zFjLFfCv`g%i3V_pH<~JSW&m$OMF?{6R*1^`1NzwrP{MUa@I3;ol)D|~c@!H2a3=Ch z_olljwxY)`W4~Np!8{1&-T2k9h zg@-3pSj*CLN9>_wv?Ea%Cd3;apAhdwsss*Fuc|d)?-t$?aSDik;8P$U`B%H^F#ZSv zYzdAey|jgnKquOL)OAp%luavCdwqf1xfg%)DYA5hei&!geknNZs^{p#{<iOF73!PqH0N&4gH$fPfeF7Y$|w7=6$1%7{KI0sW-k$uPD8CJiaz6 zLL{Y%THRF3-j#m1`Cu&>e`npH`W7eCKJ*dXX(QCY-`tO7;tn&kkUKCIFTj;>@%CEk z{YX|@T?fn zXyN%werL$@Pf~3e>`R(VjOjB8e)*CZR}R`t_u;1knL;5Cyy}tits(hMpdg!e#iW|D zeIvAAq@1sRyI(@<{+Oh@lgcQv=RiB2-l?j z=<w@q7A5m z5Sq!Ki%{*(=%o*cIkz)Oomq#GmP0SdWR@_NVnk&)udhSI{6(6-#CT_ws($5=0 zty6;m30rBV=M)Rn?lre&D5@wwYM_=Q<#CaD4}#q{ezu-jIP!q+D= z)u&ny4mN|4uAMBjpo{ywY;My`f|%9yk7>_JmN$g{uB>tn!B$q%-miC?VsaC}Z5p$J z3r!YVMn7=iGx38Jz3vgFVY*llMBy*DQ-zO*EmC(Hl92#KmKxm%G?xV01R!vs0N<9B zFe^Bx@Rl#g1b5xqJ#At|B4s*rj0W|LUU5j}BE&L>MOt7ha0Qh`T%V4gTi&xz$q`#p zO&FKQsg914f^$2J)5&)%zJ*+RHvS5^WVk`uQWfWm0VKn}fBZVumOXe4Q@_>7h$)v> zs*-pdN6jDBrpB~$Q){?VO&eb0B0>xinF<;>nt6NjX?4+l{L2$-u3yy{&USf=>^m~% zHP7A_KSgbY=$3Ohl!MMgxx_1mr_n7$n4F{kxZ0OTx3D319Qf%(EuP7I4DUPQ-RgdR!e{Aa#mBYh!ETaH2zw#Wf>nXBp)L2RM0gIY#wuXFo z``wYGXp(UwY5M$^+sR|F$qg(k7d`p{n3N@4W(cqpIOb7+mfS880WE(^3OWdu(?r>+ z;lL*QTZ(0b&`O#(-Z9u_m+JEw32D(r&%62lg{KK$<@3`5*Cn*RAJ>zaj-uo}9;h|o zI*&~@K8?coa0A6!eTzE1Up{#2I9tOn7J1)G@0v3@=QP@%h2~ zg-Tbw%ej~NfBi6v_WkdRP;36a1EUZA@meAO-6Hg8cOWPJ;9`jvq1N}rfGE#go>vs= z*OY}1Uy`!n+W&bMK4!DNewKcFgaT zVHei4^ur1XCKHwez6Hbz0H0Q;fhgeM#Qf|!enEOLTlWt(9AUA=%l;=6YMQqs*k_%H zFSBl^3|dX@Yx~I4k4em=4$dUKNz}bzHYh4`e+q>z6O38$ zv%!NX4vo z-zWM6x&IM!{FH}o1=R`K>-r0OnRmBqek0CZwX;vDjM7Sc^=BMND^2gK~DbJ zPu^gFNhd$2PVJ9i%nqt4HORsEP~7aza|cTEgVqU zlu6tS-KXY!PY2Zs{C-A28W4 z&e_%A4DB8Ipaz&1hq1USjlsse4)gmC z33Lecd-d$Bef`N)jA+gKBC7J_9M1s0Vx)CqCOXSN&uaT>Z0z ze9pk<&uz={;m9zxN!ydf}1i!*@ztJelQk;3}Y!raZ%eBKwm3*U%2wQ zebHj&;Ekgs`Q7eDU?VKkANiA2hH6&%qKr<{HnLp1OLon~;HX5)zRB1yue^(Q1DB`DMt7MZl8hZ(|c^49=OBJ*{DDGbm`|Z|4r4hC& zIU4olw!SR{HJ%juGYH7MS-MIru|zg|{smnE5c#-DK8$WsIYkOcmz21q!Wn2$6qf)xBNYDP!cVh zFqZq+^Mc#6l`9@l@;eNgzwW_z>713B*j+xF zkNLmiB)$#G@(<9zX&v|0B{}kI#f1OK_!xS!f3tAzi+E=>Tzo7t{emQEYDK@+wLx!t zulw(w4@t2MmchSC(1ZQOt;jhOH_kAuHI&7mtprXrDfUL1WX1J*8Qu$^+|f9;wW5R6 zcX_ZKG$tsCAF2G3Tq3A#T6`D8OPs{Cj&e2QzLlaUv0PpR%cd6IGLDBvXuJOf{iX_D z`^fOtvw88?(yxO;S(er3=qZL-U$>J&r!SLpe_@&Avz$08l8Hc8U1FS9Rfm}G(c1)J zdaP$A+?_UhkJovG*LfpsPJ#?E47>M_VACl_ zfie8&R8td$PU?<^tuy`vqZMap-Bk+55C4H9?UO4b)D7&PY72lQmZ$WprFgisPVboJf)*dZPu*UnBylU3<~tJ@Qa+(6v=M=aQ@116;hfRAuF9F5dA5;@P}Q81Z)H z!z`2+9ks2FffbnqJLRwo-AcOqU8K2I_2q)EN?G>_6-H4YP!+Xd_s+HQZZ2Ye6uA@A7&?IZ!OWlkR$paLm1~>Gufy*8}eQ{Sva++7}n7+=16 zrLY9btI84d{Pynh%PpXqTkk58Aqg}De|0O;e|98jva*e1NhtHfs3xBsNSuPM3J>py z9Qp8vE~^f`3C6`&0)7PnsszdzL_qffwF^>=ZmXv8fgxy4SzG5JrZQnN3#Qj-Fx$P0 z{kN9%F`l<)+j%_)E4?Q&ho7efvxW8blv+}2tvN&o`|Qov)Fjo?(NGT zOw|~zl2o_9tZ$9TH&6>Gy*5Vj$E-8Atm(;FSAq$`Dn_*B)X|=~?D895dVIl0vJ`|T zSEsGZ=UH|6c)-BiDtujX{+&@>eGp0vl}B%dXz5*yw#YM^y4-dj<{tp0?hb(Uye0KR zat=#UjB!5{>YOc3`bLcgv1-oC@l8fG@m&O!-(iN)tKv@mE-SkAY}&HT_x1Svp^b=NFB??xm&cg?5xilPIMc-N*6J=U zRDot$BXSUs%WlR;eHN;4TwDQ85dn-hx82?^Zm5pswFwGnuA&V=34e-~hfL2nD5ZiZ zEtER3hEtaMpSBv~;t^}|-KXFM_-LGA7~HY~vh2Dp(+9kcPX7arF#Qe+78E!20=*`(v=;Tj%|p7bL-?IvC+% zB4-9x6)J-9+rQ1cEAfkX=Uhl|J5qOP?+>vfPSt(c($F(|G(=6;tM+D-+0EQM~6Onp|@0w^hry9TQhLgn@RTZ5kcEQj_LKrvL%IPq5R z{<^0wZfd35+v3N76EyvP;fBhMkp?VhcQvUQ5_!#R{Q@aKaErY&Z;IwUaN&TYQ$yn! z?e>c7L*799oTsc>&Z2X4iaO3z%EaABiws|tPJ^DQVRDSk;7;eFlQEf|5QuyP zF;d{JCu2fdH5%r8kQX7Xv>nd5HrSvS2#bmsP`>w!hzff8WfqdeDCsRYodtM=zhQKT zz_A02rXPP;ol`J?dnevxoE~GoSn=hZT=?R?Fd$HPvLb|{`|&`FR>%Ifzw}Ig9gmdu zB_&(7O9ad7wO2w?1P^U1|BlML^cR_}8Suc<^|=*k;JvN=3)CD%>_wPg+Icv!pEGkG zBb20|w-{W90Qq|K{w@xECy%R~H-Q6h!)e?5>{>6qHDBcz#6VJebhu7xowwI3*i4X` zVXb9WzQuVucvLB*N(tozT93-esR;34J%7Cc}=*+1qBYF*c610TAF*eYB@CF1# z3=-1m?1xH>48zo`#bY8`#QT^fNQRAeV1r5@p`!Iel5zVsaZxe9=(U&QW9W9ilAU|f z{%P_-Ft2GbYp5dfeo|AFpeuyJ>)~=3}5Dx)GY+ zOtrPeYiL-V(a1mM8Zk9VHUZQ7ESL3`__PK+dYX0`N_!Y;N!IjHP#yCOgB8KqN%V%dI+4NB21c^ExWa1JmPIRFkn(-kCsqy&rGpSeoV zrIsE$UukEs-PzgoeW}R5aBJ8N%Cx5M#s-0;FdRI`#l zhA7}g6#+jmn6};;igP)$xBzI&)s~!TqRMVq>=Qsg(J=lV*b!BCuXSD+G05UqQsfH_ z*TNsh!IbrW%29R5kHN@W>U@3*^?Yn` z*+FGJD0Fr4ThmbD&nZcQZ^uvUPC`$R3nP+Lih|0}Ui@sE0;5+EuOQ%#x{gIRrci0m z1>Y7E%--{fT6PSsXa7Y29ug@Au_yfG6XypArIDuQy!(pjujTfBn4Q9=uDwNwISZKi z??7!<1V`tx`(m@qJ**pLS$mWu*X7w9v<7rEwl!=nTH|j;w`&eMh`_a??}R`QPC#6j z=xMM~lb{eK!V&)tqsF?8+;Y*rO?^I#_-LR2@a2pD+$Z#c^~|# zmiFoRD7cq;MpFFeh~y7{*`b)za1cOx=eH^FmysPTP6>R$goRSxGB@@AGq7NUn%(LF z(ggr>CXvvoc(wgZ?yemGkWRT9OF9FUp-bz82VuAJWFQZB39j;*C{+%6Q*ben~0mN#F^mC^V)r?8@HdH?$?;tE`1itCYGs~W5z zWT`dNkz*WOjGB?VZ{ByUW>1R$GQ0*epz?PJaNtAQl%Gp6S^QFBWJwBfJJrDOR-CVk;md`Oele z0&0#A8N!rs=^{?M%n!`i-KlBD;cT()qt+3d%eA_62|WiEpe+*)SAS{#`4(hLtMUub zzu8EOH04rAT@aV7ZoD(f2>8wea&3?_bJ#??M!Y%rbhL1>ztv}T5#}N$TZb_kmYYf9 zO9<`Cz@t{e=kd78%BL2jpNv!F#$kyCV*>t>!kGjc(dK5cb4bRH(SC%qy!6_4;NXM} zzmAY>Y0du*;T&uM7f=#C^zmJ*d;jktLhH}piKyof!s2|ug~u&up8!ZDx9!@6~36dU#;Aop>+Z2{}2 z1GB|Svu;KE{V9j7#0(JpWH+?QBI?b!$5JevvPS|I&02n1PJ^=@HMm&RVSj0TtW%I{ zyS?%iV9&sy6|n(d_*$KcTD`e_J@@cq=sJu1X?d#%q5gjz2u3V>lyooHWHI(;wAX5M zzcON&KsuP}khkcU%)D8bmn)&ienva-2gwopKpw>}LYHY)zRYUsUuE(|*Jh)Bucd=9 z(k;adG3Rg?7!A!a!Dkk@puCCPw8qa!lQ9E_*aqNWR(l$kgfo=m$7ehz_MaU{4>AZo zkn|bwYj4~H_5^>=m~{MmV|52TzYJ30j%TeT=`9A6K}<*xPWVZYO)Xh#;RE+0AB}Bg z-k4Yny|Vv)c^ztM+kMi{5UOBQ+UW|6VdjG{UId=ECj@r2<{!gw-hA!%K=e{IfzhWT2)o{sITI82_`}o5HhapXxX0AS)-sBi}GZ3`yyucs5MgM}I zjFI^;LWuprO}5WoewTziZ0pz-HR}SPy7m{&yBy2^+a%el&KOC!sT)gBfvASGHJwJ% z9_i_AV+UH=lYi-ZQ=E@8R}C%|38ZRkmVbPMPM(jBeiNd)6t(LaCz=mGQU^36ZX`PmkB~Qv6M@Qf$Dgn5bk>HX+YGt>({v5eU58t@*cnsyU*l}BrHpU?aR!%8n(<=@dsW0wyk?sAGq-*y zykiwo@}EzNmj!z`dO3LQw;%Ju`{DB%Oz!9x8)a{A$ULBA<90jYM8gL@yHw@BVV2iL z#Ls=j<4>}QRMCcxd`tDF?vKBL5l(_C`c-~!5O@w^5lk*I*G3gQ;0*{6RrZf9EH{S; zXY+BJAO{L^vXMK3YAhJpdBm~5bNF8jECOG z2q!12Gfo7Hq>T|j#lXX>{K)}UpwWjkI;UPU`C|QUPx!NMH~K7MfJ37JCW*^& zfl%VKWmB<}_*J;Xp7?&$!;Gmrgxp@1sT_ivx!9W&06W}`&CvT-azM~4Y3Ww9xi?_r zvk~Z1*T1GP7IM{iSWLgQtjk7O?yIIi5}>xeG0O<5Bwf9h7x|@X9(=3Lwo3deoTXHr z%)XPJZ)pxOK==R;j0ff-%H~)V6!di53AoGF!eTv%Z$Rqs71RH_5!MSV(Dd#{QP6m4 z_l~fnh|UOlAm<6OM3&4ep>C)OkLszshF$xsz*go@CgAa|yNnuGU3cD8VUpsb540A2Fky?$S`RHio)#9XOlc=< z!-xrgD4UnZVp8p_%$X8Q=0UDkkxx}>POKyoKf}|p0a9UkZ=S}U$xrRHWxXAo4^LOb zGJ5-J2fCG$cQRrp!Pm}Sdw1js^>c{smwGY^^C(p_^On9fHI$q#?ZLn`% zyT$Up*uk4xc5o6)+s|TrgFZ3!xSxOZYlbBx=drk(-^B|&|23<{70+tSxRnkcV0qQ1JBxoi*uvkvRw9g)9IbVCs;U+7?yWdbST`AH5DC0NBs`->9_$=J+?#WDaOQdfz?sU z#72kESKidfpNPhXMd`S>jYYYO%x*3PEa|vUTPB*bfOE@w0pNCjE^!t8zxp}TUz$!t zBksw>u7&Fa`kJ}8KD;_Vz*Et~`q8(ymRb3Sr`yNO1ocZw(V@i$J#E`2pO1cG>r*P$ z2eeVHo)E3j^|ft&Vmf0(i?7^*utSBYNFX`zN|1_pw|O?@L$$#^U!7Vp{m!cGqNFDHpwH%5fBUB>ip zc!q$4xGCRVY}G_`hK-{HuN4-E4K_RVrwp;k#*1k<)gp=Viha{>4yD@ZVI&pO_!})T zuvdU2QqW+gSd>gZ*|m?n>y7%s(zlMH(x-uJX5fq%KNKhA-g_(cPzP2U3ml!5j?1JM zy32G_ibGMHiqj7PHmVzqI|AvL#enwxjCREAM-_r6+;Hol3oUc1EgwBkO6wNMTl$r5Aykmgw^%z%GFbI-S|P+xt-_ zrZlr|p{|y7mpN9&#<^@p;d=2+j7u;?}&aX|m=@V-WzP4-4498*4=7Q`W z)MDyT;wdX1{SjZmf;8x&@`cRkt;zT(AEQYrpdoI$m6#5FYzRyw)~K<7(~8IP7kC-Wbo z0Iz;$C;#2H1JZ4WJ%J?mbD80`00k_@<3qdq%I^!mQ&{ezhj z!WkQk#96E%Lh$A$?DOZSQ(4w~g}2OA>&2;pJQoD;0B9^3FHCE|;whB8i6MXbAM}9+ zW7`~ct~wXcB&pTH8TG437B%gkvclg`djS%yyAb*LDG?Vz2&yD;XR~=BKqcsiETis^ z5Drh}9oRMLgm9m>s#7NV{LmkLMR{)7X_s@lFc`wrj5Er>mdm9b%2p;&o$s-g{|1(> z1Z3FUKz$q^T#R&8YdC^4fYbUJ%Jlox{~B4iy`lg~{;$2W{%i7k|M+O6Hk5h;!Zx}= zK>=w-NDq{fk`QSDMLGne8w6>P5Gm=DknS#}86hBp(Yf!9|Hk*T$9~)6Z0B*`XXn1I z>%6Yl(?*CK8+CuwMG?Z9)0=fqin8C_bPKd-q!wuqo;Ll3_?4VC$ykY(HfyRTB~MKV zn=Z13TL0Fx5w%DCP0~uoNUDfiWqhMZcKHG$yG{%WJYn|V=CVAB? zXr8)}d2ChvM*Ca%fm~{;0RCs{@S)^*W0i=_B_p+FhCc((0ivfr#%v=x1i?h2!wCP| zbRMSWVe?+ka78QLuZEU$&w|TZss_LG@A9wD4_mY4M{L{EmIQ{lWm;OP$~b9%7hr`$ z5}&tPEmX#?j5tEo>TGa~j0QRTVYlgmJB1iL=q5ldLFGdP{i-6|JaI zPrkdX*m@<|S05PoTZn`7_@F)Il;LsyaaTQHHmdC=~w5pn2L@EsOL(0P}z&|{xWia-7q1Xa+-5@&Hn zsqD>XPjB?1_&aMBt3VL`aulJhIp}m@Vttu#5@w|!76kAI2AU7LcOuGZM;^Bw>gO~c z%aLU}@EIyvPc6dZOz~FB=U&g};#mghWuOh%Lik&z?1lBX3WUEv z!U@1XGKPjEzt82AYu{?McwMs?8&9w0hoP^#ES%q1+tsqZ2w%)*L!>tI1s#+JzmqAW zNy~8MT32BzY{-$QD+15nDBwMtVeytyy^}q4{3t^jJU{3)jo%Y_y{R+a&vjd^ZaYlL zJ>B~=+hVCbIKEt5atBo0Z8v@@_ikj9UF6xMQW(LYgpr{``jCr?>s)a!x}GWQTP}f*nQ} zdC501p(eStqB2Gc0VUZD)loXL>qaN(f2*|)BccLHDVlDqXAHLoJG)fKkxX~Pl%@!j zui5VFVu~f)=>I9si--~IdW!`ldE6XCE+=9iAdiKaL*p4)iFjAmvoV;X-{y%Ch^Gh& zf@DFtBpSnJ8B+7|;XEUz5`7fb9xgWxa`>QgZ!}Xf&G^vhOz-k%Fi-Z?P<41IS&1eR z5`SRnaHsILn7N^~&B>19?+*I`n!D?yVkK0UIg+Dshspg45>)P8OB^Z=Eyw$|XcNW3 z`Ij8*?w_jn?Zc8HMwKX07+XBseQ3P=bM4WaG0{d$;P_tkql= zTG=m%75n?h&qEfh>*!JuSSsXqg*sPJl{X|M*u&9Zn?>wux_UdTcYl3h|Pm zs;54TKu1_BoN$b2`ab6ke|E*gcl$)tPgLnF4Tu!#B1n7y2diEN>%~N1=X#2SVv^J0 zrMw;SDe@jKAyu+E+FNm?Qr?gBIj&GChdBwR-kzcl-k|b6jcof5D_2TC#ctuQ4r`n) z;rxqk277aK96Mr7R6fh>GJf_`Ykd1HLsF1b__Ui#q~1Y3z2ma&bw$J{vQFE8AUgz; z&PTyl1A)!S-8%x_>R`zFY|3SoclKJcoBLPiV#2Ring**8B``qq_Q3E?3Dr&8dp~Pm zd$o&&SdWmzN()f#LKJ+?x;uVFjI^4`(s8pbwSOI zE7xy44f}5GyQ3?TjFrFhrN3QR+N9t-qoG_cE@?Ga00Iz+M7Bc2!g)jO8!-Muv~x&v zlM%HVy^N?sWNaKKH||ED@8Nm8a(41CbAkf4`MC2i10`9`&zHS@3bY!gX)=;RcjRP> zv}j{7&>)UN+dGsQ96Z}~w@p}dJpC^Fl4jDZnVfU0C&sOK&kM#E*o8P=< z5Y*9xznQ7(v@}S70s}7d`Dh3e`$Fz|E3Y1fILn4qB9Yi6`mtTnu`84wz+EGQtg*wq zcVmY|gL9%}Ey!}kDi7rg_rFhM3f!Ym-V7{zY~O`#mtR!^%#T-ebTw$-;c%()8rJ{4 z9RJ`TNeL&wgid+Xzx}4+66bMn2(2u+==bp-AM@Zc`v4=BO~npybe1K374B@fr$ypx zef+MT0$#8bK}%EaqB}dCsm5jVg4jF4>U(9L;Bejk3kV8q!ef|>>ZB7JWC{*h!y; zRPPqL*MHA}zob1q@RwZwx)<=RFb^m4Wvj~CnKfjr^00?DUr&ykSIWi2)pV*o#eZwn z6{}F&7$_s`_qF`(!J(<=j>F=xm*0_0zs>E|u43W#sPs{q@I|0@Nwe>0NfP(6$I+eV zm%R^uTn+x0_0i`t6h?rDi&P$$4r&W9)luTxHcl$eyKoz{thh>#lpN8xfQFRUYcou> zFo~<_BSFH%B@&73&K!nyiRxp*#rfe_%$jY`pd6LUQ_DL$4K$@S2I{MLJtAlJ@BxIT30w&9b3#cWr?9JB)rQYSTp5HgZt$lf1 zWh`?Z*1D_$S*!tcL{+*GIbU5NyEY*#=HJQwkfYxI_k0TWl=I_mY~Q6A6E7tWdC^3aVy>9v+c7gWH?j@0%k%oNwmrx*6jq#nm|$0r^kQ7 zQv8p8kaa*fQ~b@y?25KPaQcrgs|+?pR2hQ|6wOKWC`h$4qk#(XL_tu2gcj7+$$YoS z*p02G`@<(Sa4}Sg`iE3K;c{6S3~-e@eo1BwQ2dpX*5)>Cm#sFtBxIlrQ=Ddr2 z&)aRaJ+Koep7qH<%}amz%q%LmvU#Y-Mb245H~f=Zud`D-ksu*#e}6y*5g^yRElR}h zZ||O(`Ud{k=nDWe!^NRzMaTzro4JtW!^7uAa*qq zj!UR+3R|(@nr-}X{D}i1t7;qT z-vU-))}^C9&T#T2N`ys}&?Ap2!QnF0y#irdj&g?WzGy9%q;Zk}6_iYnCWQ=Q0)_LM zvCt1IeI+e{dMqmjSn+aF&kZd!oHdPo3$Qwry5ox*dq?2U!*TFPG06(9O6`O8c!n>i z2Fw5GLO-0sV!_C(s+YSyblpNhGOJ(2G{d?fa%Zc_i>o@@=>u{J>ua=2o|hd?EyyU3 ztOj%->2QODQ=Fu-ci-?Y?}>U5+t5aQ^4z-acq=wa7%r$Kh$>5?gF5d4D@o_fvmp9F8o{^fQTy}j=w}K`KR`yZM{VtyVyF1IwcFF5grw~X~4d4<5HFUlx=KTVFxE9<`Yg*XDT+NFE!1`qB*i$J8s(=!s z`@J)>z-B7J^-&%o7qW`#q8xi~Jk#*{7>|%70ZeQbDH}6PxYbK?AB1QpV#Wpkvedi# zi6N8r$s9|fq*Mc~R$pNJlEE4ciIkSh9+%2(k`l*wK}26&vB^}qK0J8wF`VhO){C>; zU}t)2Bkw0sS*=^(23!Eo4xEVUP#jhM*8(sqJKAibe@a9H-1Zaz8q)B4Z@gh_?o|}d z`z9Z$U#M%r=;6lmNW4fS#SzS2mj_|lH~1|SEb7r3P*d~KHCK*!LKp`B=x^jtI9aIs z#rr}EpF-L!yrEMc*F4%as5ZbjJsB!|xLWs|sq)J%&Qhz3ArKP6>rYRemSD zn|UpoE0`o7P!eWq0h%+`(H1=q_V4`hJtq>tyBPWEkqnTvB7sFW5VMnQ%(=`iEigTb0Gjuhq>& zxrbE`K~Q`OWL}#j6~z& z92@uCQIZNWqHJbX6mF_IP7QWGLigmsbHip?Cf`w*KQ+wO&j_XiK^?(dp+?^rQQ^hI znf<;GWMX%MuU52%wU8vhen=nJjTwzWJp%KF#0t8NY=(2dDIlAP)VW3!!ChihY;F|c z8_iGIwdjBuFM-##B-5slFYHe*`!u=Iq&KlWL@+i|5A-z(;1%TOjc8)OjAUkI!P-P| zHKPDrU!=$nt*`(VdwpfPCza0S(VIU3Yx+?LDq^#x1e zcgy$n4qNNGA(>c~z+hr6d8nmE8UZ#5g z+540xKzx`b03lVZ80pSXgb2ORttuLk4y$s|C71Hr-K!_cCh809ILd zt(k$?J7C8WAx_D-eBK?yNRrh=CZnNRJ`pSjwMaDwar#i=eo(+GveF9M^OW{r_Ervy z7ivgc5J6rPHmgpP=p0@74<>>C4C1mKN_S4?ykyKT_r^$ZgN9L`y#+1%I-kbNmVr>u z@NR@CI!50u2u4dq*y0Mmj#>bBcut{1guDU-Ygq1=+%L4jB*;FDvq21%q_M$K03Uq0@JX6!4N9t=5Fe*F*u@m4ACWOy&2& zjbjr-Pu$^pmG_}NQ(;yGRE>2^-iQVtfMhy9L9q5ZCLA+VMF3txu9dw!u5}_3zJtF+0|!&KoE~(~E-V`AL={m!Uusr7K1NAKLGWP*Fh9ZDC7a4DmqRJA`bN}FZGf$iJ# z0Z5o4X)yek*_Y31AFZZEqc2}F>f;5~<1UwELbe?1GP=1sM}=NJ?I9m#mySxvD(Jgx z`y8!Ku6HC*(P4>MyhIJbWcDi#kXKHv%>BD&&i5>JkG z^yFo&QKv|>P6;}S(r|eDGw>XIYN$&FxJjmV+n5hgf@3YH>C$*mnb^e;=F0<3*U%(t zeA-~2t*Dtwiykj5HS>*@QGIAW!s?tB0Abbt(g||jC%b^SlNl7<4BZPC`lzg5`N9tJ zMTu8e9n`5ueo7f$Hw&#)<>2>WM+XXcASxood|}gcptJIu;ia{an>g$Y^W}ppDK6;< zpfP+L(^a?Rdp)I>P|nDXwlcK}Z14fYahs~D-23_Ng?^ybIIu|-ty+ZA+oivoH%Lh1 z0XXM$9Dl|%qW{JtMjr5Sv>!f7$m(_mOdV4Tb_w9wsFoN#*6{#u#{S@nfmGSGAEC^6 zyCDZ{R6*LcvDL`lTJH|s6RsC8(1TcWQ%GDgeO71}f{l zij<9%vUAglT-3ZrYepFi>hn+9vmYKw_wN0^HKK8C1Tj17C>mfVdS^jq7TyxON|gNc zvLTu`NGoY1_Y{Zs&3&h6D0mI3E#0*h2*1gp*m;Ha{pIuvzEK4)u<2musygE`L9TNO z`AL)`?HiG~5ZGRNRT`k(6%zyk^v{zisY#e`Xl$RRpD*TKE>F}p6mR2hj*eBk)^&I<=A`g7$ZDnm!R}_dD_t5hBv{Z>I-14nxB;RcdXrd$*Y)oHxUBah zth`x!luc1w+M(8NEQDIo@DW06AdDtf^|6i|;`DDW) zKkZNww?9ObFe|@RWFvz^GsE@w!tgdHoz)$3Tpm*IRnwbfU7ynf)AI^}mTL{-c7rFu z%RrwsGw1iYJ~(e0me;(%h8}nq#?|rNmx+jAN!pQPA(}{`?O&vetLIm&B6d#c5)bif zI5hl~@%1Bazt}M7kgbp`ef3NGlI&a;_6WD)@@DRH_ z<^VO(`fe%RZ0PxRDe3ppb*4Lxfa?tc0V(-TiF_c{>!S%#wF-f>{q%@^hg4ys2TZPA zZ@8@PKJ<3gqEhbfSG?0%g4KHeG~pr7?*jSpqZOt%2XD!7R#6;->vc5=1(TIE=B7tv z+?!C`o*@-kDixj5a@)@AYJE{zo0X2t29?8un-FJ9$%j!dKx7UcS-H3RpU;q=j;(0% zI_8h_Pos!MS2C@@P(wS8wT${v1yCd z&?iq=A&^TwEKbL&-#4uO@&l*TmzQ)It`7LM`2G|4$)fO(PJ%E#5A2niT36AMqh|f~ zv977xqF|rjDn06o5cHjeU(paEbh6^B*$VClSc^a_K2iROzL(ci?AqtchO#g*|;ee9FzI73W)Ke*-kN5g;q zLT%;^UiCZH4S<1_A{8K9;jazHO6aEfJWgTHL-O*{WMIsn20@l=$o7qBX6Uo8QjGoH z^$mo7u_pxDIgKZ-RCTF<^*nGDv2?`%S|0nNRFNPxla4v#zUMVv=u*6O&s4aySnb5% zf#$sj4*9dWR3*%>WM_%n6@?&U&tgc2N&MH^FpTN0>Mon!a2YN_NJPr3kNa0$8%gqZ zX$k?0M#;QGoSyI9q{YIH1a9(MC6ulFR zmS6bxWvuaYn~3ju{#0y0@-L=1VV7`y`H@pSq|i-5qIRPhk}BwDcaYr!NU%&LI=$0~ zR612Y)Xfky@QkVDKaa68^EjwK+!4>)XL9f3f*?tFDaekIR3{Mb{zz#ot*?JqxA4LD zdq{`pjF3Y`8zO^5#|U9&G5%|YkiJ#;^5~!IeK7XlTN>rayLZ)#+_*`XpYcM=E)lTp ztNH^zApm<=Moc-qXs1w1&-y%NBY|~RjwyP!53l4`u>8#4xU3xVJ_U(O+fg}m)fqHD{k%J@HvtudLT?%7N9Qj@pS)BkbPP^&FfZuv==0E2) z7cfHjyomUTFH&MSjnC7a!b2mv2@zcaVq!u|YPj7n9^mrrVMBZVtR%m=il=p&k)I|~ zSaEJGc<*eQ_)l*W*}5Ny75jauRSv4BMzC~I^tE^Kr42Vb2`I%FbD%M#9ZIq|SK!j>}sH1Wf*!)-;3^&OEkG69N5~VBLQgoY%SeNnj-e8d=A-WU?Pp7zWA zHhhltL?E)kKtDPgz4%?N82bt%?>27l;e-?$#ufNwt62wtBoM5ZE-}x&%vMT8cb!T2mOr8g&4Kfl;DnsLbW5UA% zi#wj;E}cY~*rpcP(3-cxxfhQAwJ(;b#d5tGIFp|rF*4^@=~x@b{^)I&e{c%BNpNO7 zTw7^r-{O_@ShGn@l^>Xh({q4nA-U9)!C;?YSI1D?2>%dzhoQgW-lQZFtzGxNK} zJkeb2qC^|NnfHnx@tT*f+Z%Oe1G4XtV8Vc;ke@HdThNC*4R#ymg_;{@O^1AJr>_Qt zV>ZFv&Vw;?SJIC3hazRV#oj8z4I%gyG5B@pA7x$dV^~G*-*J{k(NnVlREG6)(LDt* z%VMUx)=nxjWj_^Y>c8;PVe`MDiQ7$SV~6DlEjL`Q-iksfZ@qUQ^i{?a3=WrrVZ3u~>2c z(7Zm6PD!8nG-rWRPcItv8^sIhQq~%yeYY_lQ!in zZX2efT|irn8lHn}=tkvlD~qgIT{HPFu8iHg)09CS?fzYMVs9(Z=5cS)A}abwtLKG) zhk(g#rT($OXyn>i8-Lnq8F=q{Qi83(LsFgRwUTs8L_Mcd`V7uoPUoeQ3Uew^kB=;7 zoayx%Gb|_@7%gzy37?Dt=r2aVVe#Yo8Eq*{oEJH^E7oc6utNa8>eUzCx8uLD6k6l9 zcxo;*XXDKE2EjIeDdXbmn!sFqq_=nBn!QhkrjG;B0?#nO)_o)) z@Lhk_`@D4ilsD@G1q(E((kI+O2{Xa0she1>Z6dA1+;ih}(j6q<_k~$yY24Q4J zUw?R?141da80lhR8^r-@KC3h~wa*&Ly+J@4FZH_RfT8sufJ7$zhWR3g*)wZoM**Y0 zbfkW@47h`l_^jjI${yBHGgqmmu3M$8H#67^A}@Y43x~_+{t+=?>fyQRoWCdxkyPsl zATi<$cgY35mrCKEME>?Oz$R4zsX_G|6rRh;0mM`a$bZJoxt#rBz#abIu>W0>|5vM| aAItpyVFRoDd@BP8cs*6nK$OXuefU4Amlu%$ literal 44166 zcmeEu>-x<)}5TDn70x&)+=7)nwB=`M+( zYZ$om{oViJzPj_~oX=lB9E&$>`E&K zp2rn&Tqs9{y$*)w-G{!R+)#XjpUV~DmiG-ulTu0YjSdJqIFqXspj3LJPN@rEpX7|- z0hHb(K2-YukNrR2l8ZuFkdli3O+HgvhcW=U)R09NN@Uq|HU^xs{W!fgz2@*!8gwV1 z7*DxoNKAM!7@+w52eU{()W-BFO`YszgFhe9@~^IADeyp$Yw5j20lrcoQQ%R>8GT2T zqdy=>#2jL6_Az$jB`XC89(x+8M(RKJ60%By11G;9)nhtUFdYyDFjtb@b_Jo+1b#kx zidBV;S2&S@z;=YuFWy_J7xA@}+cXJuTzO37h0VgRUZq|2#;_d2ba4mKvDmQz8$mZ^ z9Z^ARQqoCZxC$YYDe+?(79CC}9D2A9$4%FzS(z=Z;H!qmO9?@4iYlbcT{(Asr38r= zWV*m0nvnh$G4?3Ia-9O9oTwO3b2C(BN7+#mM+CxsQ*fdCzsXvF`B*_J&#eod*xu+N zh{A3wjyz(7jMMn>6E8^rdkNmFpz|v;nD`P-{cNrQh`DE7?xPU$;w9mKOP$!z|11QE z8iOboJo}!m@3NILvP^;o(ZSfcW0o5O*HR3Esa;vGT@Pq%Y>&N?!H0u1^(O?)5rF;bXy2YgnvQDHz}x(0M?FM!~D@NJId>bfD!9r zADEQD-~J75HAYQ}RCl$`L3mKybMo=JX^_e_Tl0GN`x~_ObWmVO0_o-zWVl{)+!9z; zz=I2onci@g;wE0)$8MfqZ)4EukW?5{<$C%rs=Vu|v{q4(o7f$_c2o1~Z%1E?O*e!T z@%%%D%ww)8H1z~5&Ba9G&=PMvhCReU(a(N|4=P=OLFT9bd!HjgajV|x2mlw;z2ar)&#&4)E3{ECHC8!^F z69@M`^NBbdm^#(pC$(%~@* z%#i#VNQ4?BKXqinC_6|0;jrRni^u#X>zS$VT#w^*sC={MPc_-tOCzrs$E=A5|M_yC zp;*TK!+%24nR`~)hf7ojsUR_oRTYX!m~<#LFs24QnVFvPC*abEsNOKjT5NZJsTp_M zFks3PnbwR*%b8nW?`8dZW{`A-NBx_rr`C=ep361*Ad4_p#?RO0BbUu>|Jb+hiMvwE zn#WPzxI~&C2v#2-zK1Ur)e;|kGXvv-ad^*D+KH*J5DaecKRa?dSL)@rbSbIE@}8TA z2DkvrCS|Uh!H)C#E)E_u98al;W1wR{tH>?G+g&Z6tLsSuQIExfKKjAC z3DEp1t0J8)_RVx2AZ`?&N)sEfH#%wIAMlDoX_^+B_%7X<+igewjhg+dKdFT(!V$;(SR|5x zS++xnu2Mn*lWyxDczzW&!`U!)-%0YQ3hb`(*cC}en>kbhuXqa8za-X*62B`Ui&4f* z#pX~^ef1d8e`$QTa<|xLap=l$bFdkCw}--_En(JPaW<0uBy6cT{T4UXgEr{JHqj(7 zSci>Ej-yygOTIK2y}kdt=`=<5^g0u$R^HBh(&2N(w$O_?3)Y+e_v^!R!$@0ZpHogX z36Q?qlgD2Ii5`E^5~hC>oACHoXfw0)(Y-Mj3oe_Rpym9Iz>CAPqkrS)(W7qf$#mLu zX$1Wr^p3_Vjo}u7#JNrh?``PTr{~Ft{HLlE_SJUtPqh3a1eDKlc&fH6zdrlO2m(7c z-d_W3J5w?F;-R~JTUGvPJHFo^L~~2)lU%dr)i;b#;^F`aa}TL59boDFZz(B4X_Gw% zeQp3Gzyi<#^xbDPvTA|QUFH|Lwk@(zev;h;HjIq%WmHoSuRK`FQamlnCS|OGX>K&2@gTDmO3tRZ#1YmaP}@WJe)p z^GOw%1|jl$``7fVgi-tw-hf&St|kQK#tb%?^wAI$32Og!S^gAA0~ylsB@Pfy?VzdQ=2cMyqzE zfX;R@mVP)Y!F8M8BzN+W5>q8(Jd=bK5f_xC|K4-uCRI{$MD;&LS)8N6ynC8#IFCAW z4}vzHL$SKRYbPSb3DN;d$dBb&Fj;7Si?+A5|t%x4V@J3bR*$2}qFA>|7Mi0c@Kox|@1lxyd1Rq=hG<~aF| zhoTN&uFSjtL0eOK&ZeUJgniU!R&=xyY)rp;=7vbX?~z)h;;3SmNx=ql!b39Xj zETwQ2Nm_AyJ5~E~m#NjhH0}y*2AcW!+EfWs?Ey;^ni_lySO7^O7*XxHR+(A4v&>uHgt=l!o_x>N`k99Eo5Fjy$` z%R7>^q73jBdG7=oBp+3-^kkwV(**H`B>n5Nf?rPw)(-*jBYSeVhsyCZ@o+-SM+Em4 z;XP71^yJb_M&)84COQXQLfd+jZ;+GKalPSudl@UQE}2xxNFioLAoML;Gn*S5>e~#^ z=8HX3{I~aA#BF8L`Gyqfxjny;;>g%~41(_I*$?dku!uC%hrv%EpQ)7@Q{(4FDrVnH z94MF~^aKvRdjBJ@D1mh`4_sx=j*b+Zmi{fSCpbDTB+G0&9IRdfjdxFXm^I%{nqsPY z1p`-@?y}Ck3n6$R%Km<2Iazr!V*gFywa~1c%HrZ(vvX|*vFPUUEp8+MxZFeL?W@yp z>d>H>4bV8R@wu+!yl;Gku>a&I2%Rx=l{sP{zABeq0vaw=e3Y2Voxn%e_imYOVOQSLbSvjGJv7&(fR}x zZu`E#OLoIC_>&+)=7Ww~WhufaE^FDBMO_K{&4lw|0 zuzT+iMJJ|k7PdDiZKpmnr-J7`i^a`L#SdUEm4o2`wNq36+2wi@edPAgPfOhPL?M1; zjKI(Qmt>He-Mfm>+XW_vNUCZ)J<5$oc?v5@Howl6L1rY;jn_lBNwM#2m}I<}-Jum? z55uAjZX{wB(O`>$RP?6y?*(l<3O5ZY-Vt`m;x-)?=!r=k z)30zS3la>A+Nh$1ow6jRU^gP1&!peM>$V2#kalr8W?Yqc9H(Jq4r}_OvOl(UrI`zj6{<|xtP+wTH^jhb7KTS zoFA&EKVLt-Qv*dWij=U=a!01d=Y|uuO8jO4!aAQMAWXk1tBi%wO3sN=1${j|{X;z9 z+hg$VSaaqs!3a8&M)`^1eVY2|Q3B^8ItMKs1Oixp(hlf*-IWDyJT*Ogx_&m#F@H46 zgw{B48U7#rkv2gTfCKypx)W|JEf20}B?HDJ#mNC-1F!vXUV6w7;buwlU{#U6S0({M2BhGmew0LMX0KKwxTh^htaDDT0h)ni$L%@H%tinE^0WbdoSk9~LCbbJ{R-+KX<%7+6%?zUkI{DQa|dbZXbbCaWH)Z(nVh&`<+! z;Rl}F2ARz~k6CkSq(=%#`^ak<>#IMCTH%*N8roOz8Ftd*&S+4x={s=Mx-p2fyLygn*sCHP3c_OP>y6vCA!1gp zM+eN>Pij1*e+9Y{Au61i54#TrJR40aKxS=*1N=EQC@SNNCgpjOlOJ+?NHX2l+mRef9P>ZD~sjb>|B}Ea^wLuZ%7a>_SKI384 zOQ@~$F8TpzNwTN~-MZ5lW z^1>Rc-n97Kf)hiZ1w1Fq(Zu_3c{0Ux(P65reN?K`ES-8Rvth4U#OBS7$Ravp;oxYE zvt=RXY|IG~0-h>M+SS6#!((r*+?7VU`sY=WOWY-JO@hI$m#Mw&?oJGt?U(8}3;P#0 zHM&cWU{)0N7#qJ|LG{}}@Qv8YTrA^}eJO!DwHoeorHJhy7Zf8aSR=oyySJ7oYuC~( zkkMUUjMzjugLciSV@>0{Z7oS|$ph%0;s*+l*G(EWAKh>`T83Z+;7?2#iI};C+s)0* zf15?#AEW9Q3GMfKuY8ugTC2^jtgVfM2mesDs^Pw_Qzi%E+_)@gkqb^i-M(B%kG<%6 zBHQSDBSpz_1lsSU@)CC?%XFAogZm>xT^Y!K53f)c#91tPt!;zMX!}m;Mi5p8GivKG zZh2^OJ+i34*9PAnzJ3(_LSVk*B_0_QWOIp77nc)J^f8-3>3XTd>t0Dvpn6Ps0(}B_ zzqu0C(!Q$=74h2lnu<83*M8_d26XN^w`7r;z5hz*+DcQHA z)E16Fnk``s+`~AmNPYR2l27WZdJyQJKFo&Q^H~you~G$QF>*JF6j&gc(l`>YU+@Qd|CunBNZ$Wl1K3y zD2bSEiQ`dohKug%1*fXYs+XhhS&Pz3YE#&|Gnm5g0-iU3)BSnkw4opgrVq48>_9Zj z<+6h68!NaLEt^)zTc(XyegE=Um$e#>lmAGU>0)6Cs5Yo#lJ)(yvM~_7RrNA7>-RjK zBCOio_)4Co>g0Wi7zk&vL+7yzHi4F48HAwxQfQjA=@War@XYcV5AM}z$;RkaHB`m`EmVWUryw-b^iv^IypIcqO&eO z7IY|422WzBGXimGoE;56KLiCJeGifl?Kj39|DB3_VHeK3j=W`^zv75^t@t%k6> z`cASW1sGN70fHpBK~(;YL>s#q#2o8)jo}ymY_aKvoqBw_-oMe1xtz-l&6gfnxcj zTN)Y(@*BI5WY||rM+$}~_CO|r)*jft!3)SM16ut|N-sit814N2xkI8=&d$_%*@pK(= zm2D?;g^<)X5W_RwvA)@Y9p{tV#rkS~>!0#KEt1wIV{FV@S2yh2W)iSTBdh0BpS0A! zSN1dS-tp2@QC;_-;<7Px~_tklR&w9e_f zzpm8PiFopRWd+Aa3$*cCRlbTpXEqk`&z5`_rplrp^|9@Bs{>bYm*)o}^NuCl+?uxy zpTAArDXzji&M2q9NCS0IY_%eZ%p-<-SW5 z5EzMkDDm#+d~W;S!oS{rAq?DP)g>PsEyrzSFPc(Jyg$CoUr(+1D$4TI=ce#fak&QT-~ICewM0I*Ma=x1_GrFuK#%LptOog#&gV`( zcrLFV01^xZ!$??2({|7?TuC2l=qq-mb{!a9*#sJQ^n}7?*E>ll;0A}Bv&or_p zO`~GsZCB3>KQHi))S7;)lqeTVFeQNgXqCvhndKT4Dh0CL2O+TF=&Inmr=QDJ@Lwxk z>mV${MV9VEFf*pmPUU$#(&!q+9@!k3k8Or4l8d|~u%oCxZ&$OV4d?fmOk|?f8rl_?aOrcnJ$V7lcBoxJBWNdu0DEVhF zW!~$txHUqE&K_N9>O5t{I+l8S7Gel9&9?T0*1SX@qH4Z0k1OHboZ1ALA3aTNeTG>r z5aEGUBxOTogC{(7JUZQDr$92J{Xv~uTQ$C6aD+GN zSj%ZJs95xFike2-)Y{DLIt=^*6=uFnw~Pr~F3&Kh{gHb9@0wh5ui2^1u;TjmCkEZb zX|vq$tB`Q}CS;DrdyiT!L|G^#_5PT|2m=h^3W*|PdhUCAiH!r8LRpaB0qzgw_k0O~ z5lL3sHu0$+PHi&|BLXAn%B``oOeZLw{(}^ayk8M2HVI)niTCgE`)yuXT<{AxH_Yj5 z8KEHSO&5uJ#2$fi5$fv@v)57M%!l87Y%0h=_uPjZBp?t!HxJ{$&);R2C0|f=s&sP0 z-SC+m?)I4Z4I;FQEKCyzElOUW{oMLBJikKe_C6`=% zef8Bd#HNkxl{?MW)?WGY{$F}+68$fsL2DmC{cPr?r7Isam?V~R{7^1z&jhy>ezdm> z*+z*O;eUqP*w`R`EbK1^Ub>tHp2`;e8aG_k+y#z7y{dRke5ypQw0&5oU)R>u-0tJ% zfjK-{%egzXfnd@`+mANiId@>roggy*$$xTEBBo{GHqz98jWh7LIB4a>)cf+X=HjQv zW9CrJ7nGf?S3MJ?;Hu>KMTc8=C%YKWcdX44kh?UETcZqR-O$qdC946kip3VAskpA6 z97yIa%GK!?kL5y`RK$vj`Ix!x3R1RetPd)1jItAP=7M^-!Sq!(yXq7T)9AzFl1^h;@Ss{!C{e?r&4kS&rA4Fcfw##Unz577 z9zD&v@w(H-Xo+uXk0OX|Wm%P;{kkg2xiY}}Fl7ZH`wXtApUmUSE^c@rKFwSVOYRBE zW&kTNgqn9SdA2NgHZD1N1z)Uf)s(&vgGOFS&mn(GU+|Fp+9#>)L z%I!&9jq9@2OWL&3t?d3PPy(-`%;4yIEV+A3Vc)V?bAoiRD0Uja>5DnE|7&?mGzzuH z=qZTKt@CjNF~I3%PjH!VyzW!bt4~uVz|5i(EOxlVkczAaIBi&5tzIJQ#)@oOT)gaz z4ydTh+MOypixg74OY%--Kgk%vnphq4Zt@UB`G7zhez}N+5 zW@dBmX7BiTLW&qb8mLp8u~fV8;x;H-UZiCHF+6`&xS*hg+j{($tb5qAWp2M|x_jZgLS3L~SDhMTvH7sq-7W64$09m$E0};>GT;xo^YgC_;!j5o zeZgcb1uLBpOkIxXRU}t?_Rjl?hU=$+cQ+ozkXj}sZTir>K)E*PC`U7ncXpgstFL*8 zLJ-AXZ(5H0GtCTr42M@uYN3Pw>_~E^% z$N7XIKbB$_O_%0FXs37U>z&#RV%L5wpp3V8=We^Jmlh@LNiCup8bk&UXt70j@^z9F ze+6<{w{hjtt5Ch8M{srwGMW?lGkg zVkBX7$f%GJ$A&~s232o$M{p8lWyHK*_TM@7VgL)bDEJSfXKI-;5sc3RKlu1O;--VCRuS3nfeE%4MxQ>(GiR`) z<-%iJw2!K{jn|}n&uaxg{`#?8Y>F;hp)4gkvrHu zl@P>x@;FuXOK69*Q@Wz$#_JobVDL%_Dp@W?@U<4S?5(|xFCEO})rnEOuGMVzek z?7@UKDFj`TLnFX0nY5PW*qt-LGT723mVN)e8?i0y3;pvDEYf6MWaIWKns>}W{4YJ1 zE=u0>k{X22e!2*xMbc&)881oL{{W@e5P(P*pHWJjo$N?Y7f$`gGGnsnrcpAPPfrHk ztjZP`S2oY?dzJ&|=(ak~j=AidERQLvp|TWFd|nk3WoxNK16#?RhlcltQWfV+7aN3?)IcRv3Y&tx%xS>ep0At4>uii#tN@xzQ+{pL-}>;-y0zo zQCR2E9&{(`Wx2L$TuTj9R#h(go(vDqAY*YYa*P?78n~egQcMbM2VBa&2vI_I0^4Ib38d5MG@~Qzlt>%__~8_m5QN z&0OYZ?IMn?(7!&IFV2&Jv~GPwvMnE2@GRp`s~Eo;ofwVCk64HxEWvDfLZRqV`6E{D zQEocCK&Q013$@>vv>O`w@Glm2e<8*Kw86@FTQ(WI$?tbG&oIi`bvl-Xzn_W$3BO*C zWEIomjB=R4S{bQ6?vgC}?p7>6=Ll+xWzo0Kk?|4sh3vgV;S!qeep{TE8u_E_6{Ba< z+Y|P;d|IaIc0B(-Sh_b8rp;b|Axm-ei%%Ua58wAh4`^WKhtEjM3c(f#C45mSXb-H& zu37hS${d@qvynp_#%9aLhGUc0t~Ma@n18d43!XX)`au~nRoHnFiIq>Z-F?klLf(jv zn|Ff=LMAE1C;Q2i5hM=xQD0oPN0p3xo6_s*MQOfzE^Qr|F21vm4C^UqNJS6{O*pmq zq4ygVNF3tBXe~(Q_LFl=OfowTmv4j7DJe8GHZ~cPSb9n?XMdbzO{5 z`_a^#5T0aW=+WZCao;x25m)os$+Jlt2s&;5eDmM+VAvO}low2`u|kJ3B28hw`uacQ zxSX>5Rpwh)kE`Cc_KP^?Nopn;&rwSURh@jQeIaSfI`+S%JbdG!t4(5`W{aT_Fg}~% z{QcHNEut|(=xzyf+0jql!y1+MMXO6shf%c1sbW;EfEj8C3ZYdXnlji4TSZF9@XLlA z_6h{LHZIh7WfOU#pVR%m25_9vPl707(_n+n>bG$pacKQg*?^k?=B zC(@N6+?Jtl;Xe-`i#f>&ygXcSYHVq4Ufjn<5l$UT?&2&?PqK>szV0aa+&_3)BsODo z_9SJ|_3noA_VZZhgsI3h&zQ-PXp`5o{MU=#87eOUIV&fnBo?qyg>`Eg^udGB$HMSJg7 zGy-NuRZIT66@40v#ZP`sU61=I+B-5DrCGn&JhD*IIoXd}MDL+FkHF+yk|&bo_EF(+ z?|iy^>lu+>=XC}jLtcqm%|>x?eAOC}!udtYM;)k?`|rZQwdPnR_x`d^!L>xKDLmKi zzBY+-4^LY0fw$ec`IUh%7x#;VYwY&jdpt%%W?iNY>1RdKvu>xEIPH zf>sA?%oIg;^pC$>%l%nhoL~=6N8ytm9;&RmZJvf0s3k6SyR4L-HcRcx$R`(s$C!*e04( z=otUHw-A#=U{m0W z@|w#62_ip?YychyB!qZwq@6wz#8PnWPeik#%!+~pu=2t9hvq>y*aRd6Q(IN`wD8(< zibUQc+*#b-d1VLoSnw_>_uIssq#{-4$vqq+(I-U4cSUW}m1!d`jGWEGSTE9_SVGXV;NwT%7q%)nqGcOyu+Dw5bjvQTexQ58IFAQ6L$g{AOIu|O&q_{vJO*M2)DVUBoHn<3 z*>cXLaaN>{PJk6x2Ewh{W7Ri^9fL%YlBOF0F9!Q2pxgE6%RdU2Is?=NZeO8+c%n1# z{$Cue`hSLbU-@Ci?cxS-#9BjlPyE2&lqG#~iBFl_Xa|EBt(7(cPO78Tj zN*jABd+(YjZ;M8^df|sR8a!}qUA_JZCP0CxTkzn*uM(fQX&2Xd_^g?Ggua9%Y7e~_ zQG?*vePJ@fTg~lmJ4U;g`L611>RV`{H3!|u$kqx7BzgQ1fWJ-TI58;nv8a_iWmnbO zbh-CflNCqhc+0JNbvUOdo>bta|MI&qprqZ8!la6f^LLlBXV#E7TS@OKBTpIAhL1cp zc2J#rSpKLt`}OcM5U;rMkT!m}>aRGrjFEY@OU|Hw#WNGU@rVFk;ru#Ar<*Tj#BiqGsl^gfXHGTbBp;8l?MDulZ;!&D&R4He_AD zq{^O^XasC`GsVOKKVPh7&NSDi8LfS=V0T3r-gb$S-`kK>`VGwy_x|)0hYr@zlI34U zhbg=ft^b)Q>)lV@%Rb{RQ*~i}YR-TW^kr1louVd)-lP5F?bWmQlAOaN`PP3-I4R|` z6v&l6G__*&138B{;SVasD3&w00<65QiHpNk|EghO|8>m%w#)JsnK4KvM>O0Y0I(^PO@@artM7=8eWK+%)wG1} z_TE4*ZE4&@kxZ*F# zIBVOms2ium1&>x^Iqo+CGUh4$m^qW+iR*Q&EZ$ky~A&{p6Oxwyf}aq~wtH zpYr-lO-o0+X7PbK3bj@5dzv?p)tA@l1;k80C8eXR$l)?c*&USH9RY9t=>|Ss=a~Ks zeb5C&!5#x@%oe6TL^a<%L=eW)k~Sc&kO>Mx$N0+0Ni{_lZ&cgxOi6#6Z@K%FVILp! zuhnbV(ZQpTVc!5BpQH2qmy3RiqCZ`n8%3=N$8%TyXmp>|DH<=Glz3$AHape-Ls#w^ z_2rArl~Hr2aZ3jMBxavTI-VJDnE$uM`4%YWjs7u4F|}zLK5<5W9A8zAK4xC~Kwwq= z>9{rT{<7qHJ>)#6^*kIUGv^@aK@kH`s{%sAffQfi6B3N+yyUBY-=8z$om-ij7*V?Z z@p#UpW1Xpr*43I2=&gmmcxHdAK{QS{TH*X3-qzDasa%~x38L#mIizFpmZ>2QTiIET zQ}b6}cQfmz^xZjY1V)8PwNTdwy)-7?m}qeyBz#L#;W!KaD_=PE6kgZhUWsZxd+*Wk zpY*V>R_W=;Z1@%*`j1Stn)lXSrlIl|0OX>sg>*lTZ}j5(ZCyUJ`uTGRz_t6?<*{tN z`d8qinK_k9W===;zEjn0&K89(Fv{0Ecc37OWSwy8m@A*2e+_^vh^A5#`%{OEUUiwd zKBQR^ie+zYYM0tg7%P6}K@hUl`~i`X{gC0bUHkC@6`QyW^uvgWkxAxrpNVKT8<4Bc zEZuL9g;Er{eYXWW-*!}{`|jjF?2`VLzn!9km6w6!Xm(a7HU zz+Q9jZu~~D97nI?2%F%6Kh~c+w3$@t+3XHqUCmgx9b9K~c8jhC9nlA?=(<#Z>=!(^ zRcO9g>ihWxSfO=!u0`HRe9}q_=KTQuwWHL;CiWCSy_D?$e?rJT9p%JQSMOAd3j1Lp z`+*%M722j49k=_6UnPoXR-f|iDMVG^01)5p7JOpQozV%q!Vd1~M+NfB{zcq+gyV*D zk~HBa-Y$3l{=KJ_Kn~2DTDoOaw4j1xwjibsRP#u68;)g6pKI94z;JMAQ z|IgT}X||77S^04Mx1|zVDqMLuYHgMb^o&I^BMqdAmG1B1^V|nIW7#@G1LTaSH#8Vs z(Kc+S)qL67sHV1dY?_AUyCZk|yFYr(VS~>TEOc{GXh?n^sk{!XGjV>lTQ_XoW&-)) zRkiEKT08Y;WSpb7-lLiLLR(TBDoSl|-7@AO^aAJdkPxETtX0e{BwInKwW52#_T8Wg zaya<(0=0lut}@*mKTS?ZVyb^;%#Sa@hOgisbB#&+{n8=3ne=^eTZfuNid^(X_8f zlC{_x5^1w*kGYF{%V<9?|HKTjKQ7Aif6qq&x0`*eM+$}R_cbB!|NTZbFhU56LUsAJ z#r#M8hxMb_?^=Y4cAL*im&&Ji&#w0zP6mer7Fr|(7!PN59(#A$z&o2{crhg?&=|JI z4YN!LS;j`h7khUDn;m2{`#cEZ123A7(hL<(ePIle7Rs+*$g{&~;djsCquN0hH@ zq4c|omQ9ADA1<}~^$ku&k0oRLerPw(Ep{ z?|e{~!B8Ch%onptOlIP9Rfy$9+IK6&MrK=@tKBaXT8KY`{grXD`7jHC9djrqS$|QN zlpxegrcW3CjbpSA6tGG{W9q%KxM8*AmE~@=EZkQtX~P{W$`w{|H~E%UY^L-(QlEfG zoQji|p|0y3$b_=8yThQf z4xiW{jU<6b|82!eP-i(V_RFZp$q-pQdlEhUnxDhTfeqJX^X7JxZf2g1bM)OOU9bLW zCXBPW&Gp{L%P0M=F{R7B@D^ zs@mNG8(c>oDKcYAKrq|Q&75r8fRrDTba1&3Q-5Pv-!RF~mru)v)SX=qOo}yLZfS{m zf**+@>1-A@Hm)#!7n$zHDypioUgdX>Tkf5;4@ky+SzT)~tFqhu(naL|1iRZbo-z)? zTG)dpsBfhq-{ld?V_(-yL)~IS&DS48nwlpb`ATLRC~W^!L8N$f#++29+!m+g2;B$iv2mjdwcsNfulN`!8Cr@2y}- zy8lWD*D~-Zwjx`SY{zHVXs2c5ktqK=blp2G4=^7lCH>uZ;*USRKS?bob8>2sp|Zi2 zQqL>;n46n9wlq}r!~3kp(&wR^~WTk7yy7DUWTXNjY?uycJLER45w^uK?O+^ zmZ}En+2^IZw=2Y0wCsA4{r3%+Yk^rOFNZFtsoiI>{uF9_a#PQ?xd(O$-)9bU(}$C= zK4KOQnD%l~z+#f`U-=$s`a^Jc+)O*_H zGPE+Y)*|U}N6tpD86H-z%xNW_Wgj1*Z;w~=P}0d)$-k99mp$DuNRF>J40(5JANlF` zOhJ$SRRGmRtb{$5Ihhup2C5ehw2ab4e2%%n(?Os-5OXlgnWzWy!?S4w0CWl#90|wf z7%iOXy}sgztOT-y+V5IAh=GyOQGL3KYu!iv%6M~_<(Z@XV2@<_!yX1kt}_eX4<>&L zzUK@L5liw{3q6bG-VE{#&M$d4#bU`k%@&=k+||sGJ*kNhCBgE-(j>S_ZX?+_ImX$h z`p>MC(&Wxm#loR8al&ywu0%=_lo%(I)IW!8LV-?KhAQUWSt!yBWj zs6d#z5AkNuI58Iu$W9FTR3swL{)OWj6ws(`4aB@pf%R+Z`IVS}D&+X9+_7HaSy?-( z*)0i;c>r)tB<$J)nkRQnFZHnHd=heSd5v8uvDw)n`-*(GCkpm`iv5oGUPjt-t%lQT zs{B|9hs}YI;T+jOdJ%i>70@IB2mh>W*4?@+kFLSC zv;l6#7UO>*r0>?VlBfW3Ol%i2V|1_%H#|5?Oun@bJI={o*(g?b6 z?@S-MS0d7=tTb$+sMTkw+u7?>#8067zT*AeRtNb*=Pic%-I41cON zajn_roPbb29;2S89(M*vY1!Aarp%VRILYxom@uTzM>NLW8CRs zk4&EiikUR)yO&-DqqXtB$UwT)e6(;n)c7rvhUd@Y4HFDLw6Mz~X6)FE1*`-t+T?W3 z(;7c*zWJB-@H35a`kmV=Z5Y5!F%}xR`-z2&Q@IXVV6fzro;TRq{b7Ol>k@^#dRB%O z5<;%%G;}}ILVXY)m@7ByI#UFu2*wTFQO72aw20wTxiilz z7X!c?9rc$gp0J_`F9LUzMOtG53>e7kbED?E6Nn6FO;5;CHsMu=0eD{u?{22sJ$Z$@ zkc6qWUFRZQ4ml19-BW^U{H3yib~m`9wZcxoNTY5yWrPJcWnTD9$n% ztJM_qS4c|f`oJU=`6YN!c_3GkE}JGGz{WT|0cO$~WI#=ZrFG zT=-JCoMJYRVwD1WsL{tOkJ&ndj|x`<0vVmp*^5peUAByv_Fv+VH)^}9IO%AH*U~X} zuHvs&u|@D$87OZ~U#L2P=}y)MsA#xaUF+3|@8db$W*Pzp_R&|Jkz|82ZCB?Cc$bTi z>kgAf*PgF4hE~7)!pJ}Z3qLmio?n^?2f(aMk8;77rda84yE`J2vdJY@*?8;`pi_zfqD0m$yQU>DT?` zibkeSs(0E-RGmBjn_tYmhjLLzC(g3#yv#g~awjNc*@XZJE2((K$txHyox5af+jjD( zVQ%S1;TRcd*3|I3;umoD*-0eNZUT0^TP65C(PEHa`evmD|?(rYxp%l>U+eE zh&7F7+NO$jeO%!DhZGrgcRviCUIy*XFz36GRn;`ZCXp<$O}47meu1)6h%46gMayw< z#@iy@QQcdM{XN$o_KacH{=et zm5VG#->qhaS5y68UR}g%Z};2$^L+VS`jx=%iDZ$hnBT`1uI#1Rtcb}QjXC5wm)2*Q zpK4LAe@)i@&C{+ac<=IXoY53IjQL*?ng}D%7?-+*(*HXnDHTG+V-|;Jy8`WJT*;Ko zI5@s%xA-{Y=l7UB%Z8DrIaNj;*9hau?}4w^UJ6<|uvIT)<4+8>bf88qmh)^Q3A~8W%b%O8`4q`C)LTrmH6{ z(a`Uc;0_@lUsing-F)uZL!FTO)eC%((&HOj=YXRUuQdWGuwb^sAL>JX$&VpP*b~x; z9?&|0{wqj{7VjpJ%P_+r{Yyc~)XcxA3yp6!_yFK+jEV@`$BHPmkHEo)h6h8X3t%&0 zGJh1OK9@sS+)#9-7cQvP4aLy3`T>18;mtf-u2#*Bvbd?I?wg>w z{qcp9z(7FLfN=`$XvCDb4i!}Ux3oKRxQlbXNbbUGDUAE`YT7qr{ycE8%CsH~UP`{q03U3S)O#fy0f<> zeQ}~Kc_4jL?lPz`3EiE`xk4o0PKt#WSP|foE4{gqHdO(S(GD(GZ->Q+w>DjZiSVoE zoODm0VmU!B;_8=NmqM?C2+$)nr(V>B&$AaRM{;=j%5YMBU)u8ff8M6~{hhhx3$*oz zpxQBp!}Mh`=Ikf$Pg5RvZ##vAK@a8Qcnrjj%tw)o(pW3-0NMFmrj=>vrxU-dzF~9e zw%#W8j~X5v2Lx?jnC>lpsroP?nK3KPuENXOfYgb6J_-Qi3EH&%b*3#c;G)o&cO1+X zgJUw;?1XKxf8Jv|6I602e6lNKzfH@$`s_FbDm-W`x@-IicYEp+DHLLk?Kms3cjlKZ zAZ1Kc*p>k{n3Yx&xhll^6I8u31@s(Q3Xu@rw$kaBl)zBl0&C6sSFVfEzDG*zMV}QP z&UryZPgXL<#S1e^j zg@qK_b?=^`q>mfs_GI9^Hy@~RQcGCJi2%RqMAsc^;O5>xg`P%08F#MN&ASzwq*Xwj zRgLitrrthga%x9ovOUDABdWmPU8-6_sB%qd@lj3RrgvRgb9s++CRR8T36 z#ohFFx|aua6fKfHvX089!mOG1&A_J9&s!Lj&f9(1Qawn5`8u)}%gnjVosEH}8AaMI z{MGe0tEQv2$KY{3X-%9$c93=O@=M=&C*|vOlc`%aML6Bsiq>f^@ieTJ+HPGqMzVPH z$G&5y233OSnwHP4#6Z!)`R#=-e4)jUZNdR>Rg-DIVZzhV*&%`<$Z zj`1Z~_4=x8@e8yo*KUb`Q7F|8tB8YCsgMj>+a@Eja{S)_P{S+K-C*O2|J?@q2U4O# z^Gzga=>rvtGOnFDYtjj96Vx? z%WH|YcUt_NOR%cSIS;d(#qGb3=-07O;W^rt|4;P|IuPQDO_;;@{uk~BrLO>iz=daq znY5?SKq-@@yy4smjXsH%?nl3_i~`y)8KNx8pm1;_orn(h?7xkId>;c^#WNcmu(_%c z%IB~2#|XugPYOaKWR^7VW%;cjsDHJXnJFA~P9-y6X6i;x#CL?SKAI@ibPF4sub+TG z8!snf9lm=a@_^qxbvY0YfAaSpO_*7(+J6TmW_G$OI&OxcLY8H|X1f zi@aLt$^WH+?JZ^!(o&1AQ=&tCO#mpK97L0)RpZ{fRCPe}Eit394hQ_?tanG)N0cBF zC%uKXSp3=yei8lY08-~o2O0R{9Md!Y*fqxsB0JJdp!g} zx0Z(tr{TW=*CvDeoXx~bJ=nB=!AD9k_(L`pu3>i=*+&I%uu5cb28eQ)XUT|;FT7g7 zVji80AI&0hoaZA^R;>>GR%S^*Qv&9G;@i6V>fk7x?_RAw{@uuc<6IetO(Zl%pEHef zC?1(RB`SQjzaX#-=!-t4Sm(MirQG1shHf<>Q-S*m88N?V{0Pu z>7tL>BQCIKy>}z+8;U?RCs=`%y^%?K365|`x5hCik)x63U!r(LHDs5^NY{umO7LFh zP=X(A7>b3uTw1Y<`-NX1YhFmQ@2r80p4KQN7mMvIxb`Z|dJ(yl-+%KQMkTo@;s^Jq zH*Z@?<)=;=d8XiBDjf%f47@5(%;H<2EU&b?;ed7Di~rjSKnsq3xm#@f6Myg11o&Y^ z7w-6VEbx}oPr@?l42_N9k(BVozrEGS*&<*@Z6;|Fy)o*6D_*?{pGe++u9PEl36`<= zcs0Cg7I~JRL=iGiWXH^yeAc2~{&14`OL&!|a~)`HIKgZK5-AN8(|pe5&wmJGl~8x% z(s~On)wr3)9$F4@%2_#0;W3WJGtO;G>!I!nsD7IjjHI1Wuf;^7vyJ=Y0g8O1NA#W; zs^tUtDR z+>+ByY^7KN=|8H1tgyl@d=xdNThS${`R|&XoR6Itm8&AcHl$npl^H@C<3D2aoJ0Ux zTtIYA-&vzoNJhJhLWXq^ia~93gpomiK(lWA(|NEU^aOdE34+3dVNN;{<%SNi7vi13 z*Hxo(i*Qa&*^+oz-vonlbj`&G2-DDc8e!RE-M-mMTPBE3(UP^deR68U6xh?sp!1O} znd)a+I!&ZvvP$oWyT^C0mQT>w6qARj8Ovw+g9Se_B~hr*LIF5^fA3Ac$dJ!S7c%Pm!;svTXQ1YPt8*c7&py(?AtBmd-?ax?QxE}9erG} z2t#rVFSq9LDKT%tIfIXP=N?M+t_u8~-{jenWn*TErFu%*bDlqLdW6HS>c#sSx%sUk-w1+fAxIXKT1<`$FCS~Y zg&nq{pb};~r~Qh0`w$7*lZj(zT^nF5l&At_a$ zrE*nRe-%`$T5$f2S@5pEj9KH|i5(QgLg6@jku^jH<3UM^O(N|7HcCRUzKE(QB^bvI z$wiLlW7D;O@g<@)h$s^UYzu4K|MXg3eFyfKjx%!(`j0i1EWBCW_*GGa!Qe!TMRP8B zpFSNlg@|_*>CJ_e2%cUe#c?SfrCSr3;1eYv___Ep_L*k|^i?V7tp)2uI5r(q_eEgY zxPNY!cnz22Zo0cg_1-oo|I+72I!8{r@oyYiKB+`>=BN~)?K|)7Y?V33kCPOVnz16p zQ46vQ_endGl@hXkEz=Zgs7mHHzQ4EBMGS#01_)I5T)XMnkX0T?!jo`UQ|>7UZ~9Xp zldAFjlOya(Hp42aY4k2qc2VS@GalJWlHcTrQ8tK6z(V9$uPU(H$nXGCE%%uDS(SAHQ35obhM3^YWT^o~?P(hjN}{_EQ<${Y>IRX~#MTzYpT(M=^KRQ+_J z+N~8Z`~CcOEV`-^lj?{oF6HxWw$|AE| zy`{bLpa7QjeuEVMlBliN0Z`%wETj*$#|GffwrG}4_Q;yipszTz6WSz@8ZtdO{e#*x z4c@e0s8pJM|Ib+6c=+gJ$|*!y)gHCNGn9-pSf{Q8{za-9zSEP$xo;86i#FPx)RiW* zfkJ4>fI$EL3u9Zfo9S|m9*~XCZJNLRw2x6ZTYj55oHddy!Ce7Y_c9=Gjh5TIf41o^ zBE$bcrgE`Sa?nE>X)g-`8CaSg7@YhkKYR=XG+*Z?cZ_?Ac8uSWnj+^x<4%7pUA`-= zNu`DsJ36Azm*?nuc{HK^M2H4hAqXIvXNmbdGi_U|lZMEb`7af*P|_M$m3$;xJ9&9O zSy@Ks>)BYQY>jcG`oV}AkEQn7;Wz(Y~_7o1yz=qP4WfUf=gj6JMel5-@#B zXypvAsd!LoT1*cfdN4l_Tgvul!?Rg zP}MA5KixNnn;KV#+VOF-TEj-G9;U93spGcMR*A`MZ%xC;gGye& z;34C$lA`hW3-#GWiaYdk3vJm=4aWt5ao8J-ZVdhJoc#hbO!}S5HsnTv0Z(Ie_IN=f z&!zrQBByo0cZMBmH9GhfoxN?4yqZeBV0d6x>d3DCxC6MsXJB3~w@i4q$RT z6Dj@R>}suhVC398W-U+N?`sd_xuT=u=uteTT5u}E=_f)-*P>PmLE!=r%QXY|A12aema z$aLbL$*h}Xw_FjCDzJ-m`_DYNpV0@Tp&18U)Rk~8CF7U1-=hgXgkw$6CWC-q?{iBu zRvjp7_MAs8@DbiR{H!xsw!Svso&DIjAxENVU&&@(`kOOqrgYMj1(E)GcrSM|5b|TW zbNw7B3d3wSC(POR081o-22)-^!^>o=m=bqT)HzeTaf*11kXdPjpk}Z;lp}i_1!c}~ z)uZRTh7x*~Tl_kcwSTR0dorDEk0xkJ%i@Eyip1jG&oE8M=i!7Zy2TO?cQc;@bN8+bi)VOL4`tW_E!@%05JEnv zMIDp5FHQ2NK)QawFTMtArXFhti<(F`$@sWJ=sXB;=zOaaYlK1^Y!Sxo^ z_sAp_{<3Q~C2}PJ>_a6d49nC)<;p`&=FxLbmYOFnX3*Z)WvW-I5#oo3x^%D6mORKd zFkt1h{XW9P(cU#`;XexM%!=r&9I|a-EXrKbpd}DU6b+2OP*0lf)xH|?H{LRl6F!g5 z-&tUz`@y-)AUZpmpmGRVGR|Pa0vTObL(cU|yNFe^Pm3WC3bjl8&WPvNj73+jQ6KqH z|CTM*S1orcKkw$o0w&!P&QJQ@J4$w|$3{W`RyZD2f+4Ah7u-~CAD6|_+7*edqOvEL zZ6txNE33?`>YQRfD+|gQ^o%7dp>!ZDML_TZ3BVEoFc*oT2_7|YlMQ2cDA3B zsGDr;_39ztfYHFPiW|$?FS{(osvGy5***y7zh=`Oaw#RHJH)ial0(SxXJ8^AZ=G>7 z@@K^mo~2L7)_#psWJ(_c6?SE3WnE2>Djj+D3~y= z+OCjprzu7*oaL(uH+}8yHl)^oeNooThO0g71f%%jdo+9vrS{3($RLJjg$eSgWAP2y zas(eiv0IZk*y&XT3kb*?qsVrxVA~#>9YH+}lmxam<6aqtlN#937G-7_Pop|;EbFwX z8AaOS<`Tmi_40xwA${J;+5>sOrAeA5FD7#tN$XrZA zhhGqPA;TLvH(6;y`tWqsm|tFJr$vfqM;LEEd{mZ@S} zoxm)uReq~p%Dnt;ufpU!l@v&FMUlCEWGv^6ayGKXFGxb5E8pJjGMo&CW@G&(N}9rY zO8sH(FFMA#X)=C-!`1WIHc4H1HTmR25|yTS=%Sp?GFS;XLMJD_OKY<1W=JSazdk9e zOOmY)h?eTAkF@&)`e~UW-kPh-Z-~^~sf24>w89|%9u$j@bwT#fm~|5mip;dXleF$< z4GUsL72|5>Z97Sw+944lW!}hy%jfX#er@_yk5ZiC#pZ$k#p)$PS&dbi0!f1n80yfk zAp>M%WGsr3PpLyBjOxnq1RB^>rIB1|*^3T03yW@uiK{ByUt^g**l8A^8dlAEPMOGJRhT(>mBZ zXdjz)*x*h|lDB-Eua?uGqimcNK*K41`4=azZszlOMG4j3Atdu{rs+71QHu2|u?Sn* z`>!*_Jl1J~hh^70&7&BV$z>PraP zms#AQy%2ci4rs9vc1zR~Sk%IgAjP>?$z{FnFu4aISc%7s?wg@AZfvd_bX;=A zh%8H-`+?{adeu%*S$bHN+BTxAH#m{1rL}B(z1sKziDjO1DIt=6p}wX2JB#^}Kih|f z=yf<-jw6aKTm>HtHt=r2tjPZ@e3EaI8N}1ZRPJ_nUBeQ01J6Eim*F$~$!WaPuE;se zT%{CrR?4Q(S47uRh!|1Pmr^uOrZdZx{0Z7aVh+qaRA=$aaR^>fHS}%pfhXQDoTH>I;wWE*Vnjj#4pnDGV!u}nV`H@PSX^izc<=%O46j3H;7We7+ z63X1IWMniYsSgowN3j~B)(b0N`6c7Ay<0Ps-)d~0S}h%8X`gT9klK|`XSjLlFM1U8 zdfEI&6n#@shN!n-CR;IK^@Z+fBG`W4Z@@$+Bm6wl61bRyR2*vsKp`&Un_B$DaQf%Q zwE;~bpZW{Mb2Y?FIEE#Tcc%O3_m$R+KF1~J>TJFC&D@$kmgfD{cM)sqlzy_mv z$mEoc<`DXJRh9^Yp+du6AUgv=4z-kSZ7CT!tA)$-R^EtJhZ(JksC|3%TrFUCG5X@^ z{)_xUA1yBWbJkZT+MQP^X5YSi`%M8?S9|0@6)krHS&)U8$nyZBiII(53IaVv!1oG} zp?H8Q#MzS-MAiC6Zrdl&%raNRKYP!h!6huHb8;maF7TWdGfHS}K$HAQ(&`{}?r^F= z>-$oyW-nrk8Cmo~08Dbx$?)kh#cNQN{%|MBP*vc$VxyLAhdDp@{TU%;>R;;w|JMn1 zZM4Hz9_-OTu8}0PD(7^_Mz!qSE311PFmas(BXiL7J@+k`&bUAy*@IVdPw`S+W*Te2 zh*2Y?&smPlIz#7G0(5T#27?HH=r~MX4;;LcC?zJ1eKv%87T-Z&dL`11C#}R6BCm)Q zT=e-Al(z~!7%w~ZSMzM(EdLeC(nhiOH2IN`f1Qprc@%U%`ItD9T%Sy+AEp@$CE5u? zH4wXe0y%eYCiz<49zz$Szf@{{X8=tNabUo^DTV0oc~#SVnE7eJxhUAXdT;tALmqwdkZ) zQe7~XD>6IS zTF1rgb30(mnDn%JDWN9miZ2e)-V{!{;K?aIi*ku$J z60uc1CHy6VWrXiRGAtaMRxnIuaft!S5Y914zF8y86s2n#7DqI@>I1DWddCQfvbhgS z>`K%@N>n@|F{?t^dVE5TDDCZ$p^I8?+OaZ2-BscSZetcK+R<&fv&%ny#?LKKGB!&m7U5hFH6?l+8ERtxB)!zmz34{S`uO1ba{k zNhm&?p|!(;O^R)<-wUc&#Yv&TJM@&w19CZf$G@?^i)#-$ z^pv6q4*DynPtXWqOQ(15r5u+RM5sc;?t|VzPRG%K=wP@d@#qE5 zqhrr&%4^|Rz^0>cKOrWV&b#`xd)a}adW(6fb{$gQEel~(ysf8kDdSd9x~SlHIAE2= z#yrartN*3Ay~Pze((AE}d9Hy@q;pNJlrf(<1ZPN>lodxZL#oFilerKlV!VBXx3N8^ zu>wdv5*~wB&c4*XBX5FJ+9ZZY33_BdDRbGzhN8G38=g&~u(qXo8Ek&de@!$~o~-@P zFErunZie`f^}8}Jc&%S9=Ch?Wl3V}!4Lo>^u8>tc#`UWI>5W!tCdInzE>&<@5bQUR zbv*DEtzWK#&b`LX<6@?aLOoY$>OZ4icBUf^l{;m6&JX$=Q%G|$?I&%av@2eZXGHE( zCvfZ={nq(?PnT6OhU}mWhCqe17%+TN-wF5>@?l0QegEgs$u9psH0WFco^n2gtr2Kn z0+boEBJ_psXJ*D1_zt_YR)vmg{gbyoKZ6MgVp6qTq{d)urCru(`^SL`~`j!g@hlE^;4<4GsW;M z9|t^74Q%^LmEq-2?j-#PzP!kVrcbIp-QPxnD(9`@=Dh>M5@|4h01aJ~hry327j?tn z#vA<|^LAn&9JgH~_TGD4d>VuK4yE-))}M9Wcs{<`{q8VXIOZIpw8FE9%i!2+3SNw^ z@BGY9dQ-`9eSgl9vTgQfx~{EEI6w4n0;kL3<Sg8*ugx4JdYFq6hDjMEgD`FcrLj z%sNCr8pr!`LrxRKh&f88&E5sVYdFpw;#k8~aMY%FpIKggq9>~#A%n5gT8riNHk1X~ zKyxIYzi?q7b9GH8( zHxcJm7S=dhp?lbb5op4OV6wWf`TcXn0NFSYFjz$(&D?HAyvl z3|FK?RTWQazS|1>QgI@?PWBQ5xGpKv z{I_*$)~<6-p7$~;jbS9*v^L5w3Xc)V4B1D4+SGQDNlc8`PnSZk#sAdpbi!jKJR+*J zQ~&FIk#`{(T z4|5)`ayyjJ)NE9ib$L-S1P|_$4csNjW?$8H{Rb~T$6EqN{=0fs&eXLwJVPDFx3G-; zi_V+!bzueMWqa0bFhH+X7#rnUi$wHw|y zp+7^Eg7)%?!Yszc1_Le}pAUkDL0z!F=rmpK=X`DMAkYOZg5+lV;$Q8X(9CEh5j~Jr z#4JNCu7LU1p^wTQmJ{|OBdZ84VyNxf**%OnvR8llwQIzR4LQ7g)7uecr;)-^pI7uc zyJ7#(cXbEgdl7**erKNzqxk@>{J7Yk7IpuFZ|C_Y>$#>u;&Q%V@SmZH>nH=``+$5r zkt!6+h!`@qsNyy|x2~~ImA|I6IJ_HLrWbhE4}-H24F?4z#$e?wmojEZDwL{>8k4PX z7c!=bo28WFR0axGg6F7SnB@u6!!EM~EH&dR3$Ov$z_02-EhO8yg`rN@_2MCVOw(-v zv7WBKJ#5XhSHACh>RW8^%;P8_9y$m(IAmq*abunE3C1QQmt!wtAD>m!y~S4j@M^J# z5Oah8Aj2hJE3;q}mclFUo^=F3$5A<_;+;`jzHQuf^fa_S)Zi`krzvXXh=uY`<6_dl z6Kh`sd|S7l_d&2vPZxF18NG+{FLObi>t{>HKYL%W%{RPiTfL?$i}=Vs!?#t=B5XuN zr_*20&$<{2Dy|W5*b%btsh!Pl_%u#6MDI{cGqmz|H7@Yr=b->A7a~1Sb6eKKyIw-2D zs+o@ggj01?KX}kZ-Ah~@yrb^9A?RTSLPZunx{U?i@_J5qDP!_MDe}TmHm7M2hhAkl z&1C4zwnp5%_4s;MU{NiV#qg%CftVvwx!6({p*%|uu`#;J-8z%Htc!!bH)cwrEN({d zc`Klk{FyZ9mUm;?DZj#&7yC7Y-ByaWhf1i56A^MdE3f~WXTFycdW z1Q6DS+r7rySsiVP34`P>^<6EgT!>^JJt`>eQcS#7#f@}u-4*^&{uYwzls0_Jl16AI zTA;$JBpLerw+K7z>9Y3)cV%AvRY|tX_?!Q@^6q4gFO%?NZu%_=*kyr*p=#2g3HLVT zti(NTVB>Me=FKg2$6IAWMz0^3iZI(qE+6DU;nshi`(-@-WfwVm=9orS_CF zb$tE|_fp_WLoLr@$}ITdNj9;-yt)qGf#0hicPfka0azCep!W!lhM2Su#(%=^xY3WS{sCHt}fqak#9| zh_`QC>5zwb!#}=hh&0>~aV!gyxXjA%H45>Gy-2Et%QzPuWJ|q|v2G!!6qiL8OeG&M zlo6;mi>Oow*cCe*f`3{byMFz~Z{s_ea=u^$s<4-blZCcx)Luvujr=LBw|oqZF& zzIj^_biqx({dXixiM0syP}{T=c=1zT@c7oIX`^}>;K=~~;ZKi4jnCiQASxg|OL0yH z!DC_s_v!Qm6^LK7s!N=qQ@0*NkYYz~OhM^W3`-gMgnc1W@jJxJMn>u~7l8*!(CT`R zI~K8R`l=5BV}%58Fb|`CuEhI-TO{ji*@>2!3IG37@C<{v_fhk(;8|K1hh$T<`$_<~ ze9I9v04BFdYBTgm;gas6ML4@d1f7Jh+6-Yz%oTeF36V@+QoVM7(}eM>S_O(9bSV`& z-<3Y^EZ+kjtkrd{K|F0Sv_$D30>oKpJ)xo|D&{PFC(m<%5v=UpwyKHQty0yZId z;MrCB{7&9ZHV;yte$Y>$d!T}|KJ)Y!7qM^1_uYyuLu1?x@ik<^)s|UtG4e#L{D3m* z7UIC7W+u|yxqm(7d7Z0+^!@iK$l^PE4}jzP_}?;=&v9GLB2gTEjXZe8C6uUERQGaJ zw#pT;gysBPL*x~^!+y^B&5H&ZcAxw&|FEh+#*5vn;87cN$Ve(SzCXQ5mSd;%o9V43sMNT9hSc+?ZX?dh^o{FN9 zK>-}Bfvm8E*vPg7?O)Ld7N%ZO##x5tEurV~pp?yj4ZW8kG;8722SA@myPuutUN(?G`23 z;9aUu#SGHE zz9EywqV+uwKC4xTo-Rg_~X(TSp3Rc;PVK#{{JfDZ5O~o+~kHMhuRDb14(9G z*OL}n(Xl8I*W>lnuz_3aim+{-Dg&}b+?eVgnuBTj#nHaO9C{+@60Fh=J3TDeTwCA> z9O?MF@%a4@izv>66(kP>%g2yLer`sy!Oo;$eU1h6P(SgYXj+75z%~ByVk1EuKpZi> z!?QF)D*clLT5CI-rQ75AEz=ukV(fE=OQqSWwA&0`A#jZ_?-AszorHZZqaT2gxH1F`7 z-^&i5nH(g;=Y4(AjtNW~%X&ET>f{YE^~(joW+&@;$q1sFu0C&? z3J>exEw^3WZ14Gxqwv_{rud+LDWQ&;NB%ED50oNvaNj`uoE3Ib_{}jg14NSz2%S5I zfc6O%Mu!Qh9Zy8?gf&x_C)}dY0LPPz@n$m1Sc%!ygZOc%CU+u)5;zWUM!Kt;)I3B0 zXH&;%!Q<@++sp2iYPY}yUI4;j=m@A<0HbVG&#}ott^w$?bYd)OTcdv~LgL4sCle)1 zVwh+O_V1ccyzsE=7pNFoG^*ma_$5W8$Xp}_wk+1RFcRgpJ$`T$NHFxGc`B9(Gr8iS`hOM8_z$tj&;s}*xBp3w`1lyBEwDUx z+qkaYzr!58{U)EZX?S)^6YvhJRbZ5X>uKM!K9ZrxbVS$y^Nql zn}ikuIJEwq)yt1TM1Etih8o2-3%c=hBy`xoaz%MR_t=EAWAxotGBGuXC;Ug6IVS01 zD~c8n9{)Qo(bNTOiwPYQW2W9Z3}IOOCn+wY5zz9{Q!MCZ!t?BGYZL+_BaHpu@AmM~8Ek09bm z!5nOm>v+(z$IwTw=?2}IJ8?RrKbx$KeXWbZA6!To*!yMYUk0?cT6Oz_3nFM=UxvZP zYVIxanHI~T(p>Olv%=z}NW9>hz?|%lg;G{R8V4quUbGNT78ua618Qpj;E`%!fq^J&ViSE z2Gn9x?;FD74k~t^Rk|LnRe1ueuhfWN+J*v(h67z~ZZ}(T5*vG8bQk;~U(^NQIbF0! z0AtHpEqEv!!&=zs%<1!z$Z?1p?I77l{_;?QuNzUN1Augj=relRtAi&eNJ#M{up(1h z{k-Ss7xu>lf!qT<^*B`XvyYvw&JO-^+5AB<1X~=!-{3!&A1F=aiDF-24p*}w1Fe?QB; z(CZoNi=Jjo!4(D7ueYGrs5Utp6yVV^LusnDWYp$GIGAZn1k=S=C93=_HC98pgl+^` zg|Z0d>?Pk|#8!yglGwq=*l;v71F-hD>v3ldY&wf&5201g(of8sozsGjAa96AO_by~ zV9;fer=fc*wknw{&^OPy0!z#nb98N}SJuH08kgkXF}d0r?M-+)dJu^z0diSJ@DKLt z#H|^vQxt*iwE|{s*1agEpYMms9LLI$2^%-7&qa}%mtu!O7NviIO_5_>3&1bc)}Wn1 z-}wr1G;-vS03bF1o*56Xoh!k-tEatZ^(SjZ4Dk7p`G_hu_MV%X;6W|D6455JR|ZG) z{@RkGIS*Ixz(0(^5EoFd7D@Q~kci;RnA_`|FsKb?F#q!X=l%<=i0}44c|L05fQB=t z+vj|NnSV>k%!LWG+3oK%0s83f9P*;XO~yp?UNMUvTy-=t1qFHvJ|yX* zQNy-S#UEkZb}9xUfJkkRq?(r(0=NCPUKEx(UA)1e7?fBAk}Oh5NBo}Fer)a?Y;h&& zzl-B)v0qubO*;~_z(mowTrq_8yu8g`{w_k{HEf4GF0W!^C|6xajsWt|7p|aWyTbK8v^S$l>>+$r0o6qjVuOePb+j56j*q~ zA5J=4mprI4{D8Dvmkt_jI9?YnEfO6cigP3_(#Mf_sT4~t*=DmEr@P)P9cSC6<22rV z|0W-Qj~1gE@rSCO7XTc622uQ90Yd~ub4TK@Xq!%iF5l-;lu?^HcleBns9gDM<7}u|T1RS|6K!4O8#|4mm}Dvf3T-GJLdB~x2M0&zGg`KVx?iYXfu6%v!9bUx)zdqdGoJr7 zIi?JmM@m7b1Jzq~myDi|m2BcyP?Kr}YJmk;Hj8u3myl4y@=ityJo@I}Pi@9T%EGgn zC5W&|Vi08+NS3Epek*f1zJEb5%Znd`S_8cmb$7&K*^CY%QnW_HBfKyi$Hkf zhPSc@v=Q(g?$IP|cgIZ7&6))sh8codk$9TWj4A$N{-kH<6i1uu9sLp`FK(`aEh>9_ z$0|+91AEW*tD};L|SVw7$Na<&NhM1kZ3t1{dL#|j;2Vy6eBeZp`-|)=?*KlGv z5?Qk;ej5 z{Dns*9|`EDrqT-Zj8(cNVvfS&-*!&m&xd>yCrOS}QUItWP+FO!J~eez{F+GTh- zfl|BW4If7E=nyrc(Lmat#2IKToygzW%u#k}?d<{!7KaQ7K+lXXpx&Qnt!ola#%H~8a0?xR7IBnwf@qc> zraR$o@hlnW7biTHSm%tJ0RXQIV4&UJtv*$fn)_b>$ov}HvRo!$7!f}Q`o4i6?b~OR zfxe3dYDp_+on+=nbB|s#m7tLWZALvsE~NyrW)L7a zR#IRb0sRW#@~>HN+Q zpO@`_(uEUplJ>{+I|G$d5|C$IdiZ7%`dTWAdYNk=%WKXoI4W*#_-I8^%ARx_Z{*s) zh&KTW9K&d1pF%re8_h;J+?w3PMZ=I_EcSIiNn+>QQ3rE+ z+wouXBy^uLY$qiB?*2QmUT>-DsR5Ra@zbt{@d?yT!ZC5(T2iW6;aCc0f#w1zVzvd7 zGO4?yeZ@i#3o?9R91i*H(D`9;#9i`1-D7Eh%JFs|Oa+ZZag!~xuZW1Wuz5pvlsjhJ z%O^MX-)5ve$gnDc#Q$z|$5yKJSv|U|ya4a2_pB|zYmRhLLDS*noXw@ItyEl1{o35i z1Z_8twGQzLe|K(>zJWh8>v7=#vt6B#D)A+?&&SjtjeLUfGTpr12g{0;C7(qbrpp!$ zV0Xv)ocS21r+`7I3dpi_( zyYi)#^lzsyuJmSNq<-v$YK$mFkACiXSEfp-p7y6z+W!&FXMiVxJ%mzyv@P_Ne9Gr# zdm73At==>>TZsX$k}ope^k|SZEU?u({8}RGa~c_$4oRLQl8$2%1t~e;L3Pj%MbENb zn*C3C2rp3u`tj$?75L2JN%EH>%dpYWz-VAr>`dlu9v)ZQL-(gtRV0S1Se-NWHlU<0 zB1%a#;_Q?E1BhQs30dKX-3HKhg%P1fA^NA6y54DRTr71t&kT8#7L4C^O)R{Ie(&B zCac^ZGZYKvpGc+2jP+-<6@z@VQet0y(=;~n#<;1?{1?PzMxFKQ#(lnbHWVDqjm`Y2 ztf2UNZCvwm!^hhLW$rE-`rkm|1^Bw~#sM&Nm4Tk8?qx_Zy{s5j^%z+0kAko5yhMQ^ zeRRfdV9xY)ZK2$pI^i6g>n#^u?fhVH)lI>Y5#~$|9&CG>Nv`_JTqGWl#UOKWPamk# zFf=t6DEg|GyecFQ5UWITu}ssp6@@dJrDn@utz>({nZZzzYxa&3P%r0&*t^xR3z@1X z2(%;|Nr{pS9h8^|VF5wW zv;VdEZ588h8zGEfgEl8m{~U|j9`eGwmX*g5n(UW$XTNM}#&_wCp(`tDUi`)nslE=b z#@}kVOW22y7dgP$?*!z}Oqge+9s5%O>XrscVv0q2o&evS*w~2^ee$RP2M6cnSUdM)LAlM%3IFdeTZqSL^yKowbGaU++}e&5^KzV*&XYpPQ&O%rqR`Z{tg!3Th0y?CA+MZ);}n= zu|yxZXwHEXsioLRN24T%=ud}G{6KJaDT8ShzNr38pqe5ey#7#f$sD4Gd`%5QIaDNt*6y6pNm)aB_Q2Eo$r;p-a zgpAo4A> z=na)Kd`NC)KCF$aumewN=m%yR`QG}@*Bkx+GN6>fHdhWxiZZfzjdcQwj!55Cb)4=} z*8j!frRu64#-=CxDWKq_Z=*FX514XwzS9fP9Hc5Y6xyS2A|bs-cbLK+SF3ja<8;4_ zhC1FGv)L=Wmd-q5QOT)&SVswPBl!0GoA76{v$z=1ksZD%1t5bf_Rl#UI(YC2tPgV7 zF!qRU(cQqNW$-TwgLLiZ+eDwdM8%LeqjYkoi&FKu>331Nx~TA(?t2Tzq4m}L3d0}2 z+iRE#HGhEKIqDJty$09Nv#u6!f=LCwg5sVEkVIFT3&xoET;=(Pva;Zyc>?oBr;Jhw z!Ah|akG`T1Dxp%xj_CnEw7Anvo6d3*nrzY9!J`UJ{I3XBTGm z(~PxsQvvTZTBO*=CT>4)S`m&XG-^A}oy*Easib2L5qgewn3XSFW|?JaQpfuvV-u(16PT$=O}j_?l1tjJkjoq88sNufL#N+r zR2{>9G_*tu3!TfGQ5G%%{%|`gXrm+p-51=hD-392+Li4D$3M(=u0k$6KnnF9Swoh% zFf`>=kzpr20|j9Q6Jqx+A?#!RCHw!i_tk$QDIuLocXugDBaQSR((pcf-}_(OJD-`KCZ6-`Ip^%N_g;Igcm4u_edV?#gBPqO zX*s|0D@q0)=s);QB!*_N_?C;O9)A->$!9=WL#u9+M%SQ69U*$$5a7Fc~uFnJCd5Y-I#$h$JY;qXR z6#dz}Q}P54PG0o|m1|I*wO}@Q)R5!W-A`mh4N@}~vyry?Q#ObffUhX_zR$f>Es$67 zy@9+-CE~GOS9=!8vMf?}Xh?jWtvI$WS1L_hqO8^=XeP37-Kk7mEc0T!ADy!FiRuf~ zvGE5#kEc7n^4Gz*vi!qR0_Hjgjp#p^8{4nK7OG2IMzP@q|AMucfC89)80eL5fLNVC zj=icD$B10bjyY3Zy7dP`?5a65ewVL1(_s#xlJ+2Z;!_=Z(nD7irETRH=#nr!%c!P( zPSMl+SApekf6IHRe6Dz_aon$zBAh$h#JxEuxwQe~+ZwL!yS4s7e3Yei9P5e=bd&bZ zzmCST-T_m#Xdb0yNN3j!DZN;P3; zZFS+gb7`T}vk`+Gz5$d?Zs+%gZUzqdE*?0ld}bI^Wk(9PpGv(AwGa9_GC9wX05iy=wp8K&Cl?yuYP{`I3u3ox?O&fyAw8H{~((f zcyG~~S}d=0;2>iAE>&g4U+$pw!oxojDr#U62~xYv)t* zdRcbfeWW>Z&uSud$y=XHMC5xMXyAc(HzWZwW^egi5>AlTpR(_tf#x zXwj0y?`(ZLH&?=w$t1;P4X)KV#IdbQIH?ld!OyHl40HE*I>5H3F zmB%!$nBCfC5W8qb_75cjGlH99B_v#cZutIG^y$ngJIgiyn=Q&IdE?T5~ z_E}8C6}k02RHh0ObKa|^!&nBOcqKN1haxq>q;>%=>uaBSs~+aRBJ$kZ^-Af~%S!tD zCg=1SbtRGIq5-qXT7YAl0Y@l%lrmra+T7HZ5yFBAS7O?--PXVH>moA2XX94uteiiA zXb{EAbPGgDMrVx?U_(Ga=IP0087eZkY{9#o@JfH8HR&C_Ko@_vEzwktSf=rr(dEmV z9v?+5N4seTN$b*Fthuw;IueTp#w=a{DaISF$id&4u!!8rr~t^NY%?*6Y6_X_@Ciga zp7!X98=^B%b)gCqB;{8mgIS4a)oCSbFoQY5>>AKPaRN-RPx3ZkNS;(d;es89ftByX z_yg0oQ+Kka6TK>hJ|9(_WyuB+48kMQ4j#xk8mk-sQ+Zp~bRH00$lp(K^abx0{A{kx znR}7$DZOE*5G|*KAxt6{{bsyepQvHGnD~)M$$Bghw77~0s%a_yvw?wRFV{h8?%=Ou zXIjx{Nv~He-1o|eCG=V#xypksB&(~Cn%4Rex8A#n{sQc#&+_cg%!k>S)xPpWl2^{GRbBd)wfGnuQDWpEt~TQ0 z?&i<5OdNvg5i>y5EIk3CnJ7yc4gG?1|6pmzGUpuzj}pLU>@MlcBFMg7pY4tR*h}gB(~=mY%1}n|_63=G ziyxagnyf9brpH(|bl|T0DkH2X3a(uG4`=AKUlfwU)2%lmxr9cIeN+B@ER`y;8I~M| z$C8|V4RYU;&xgN}Now;84q`eNPEOC^IUEc8JL7z9)qHIh(JK>O!WP%ZTP%Oub~${C zBAcEM1r;G*cV4V=$3_o#zIfwTh@Zu4#5&?i=IQ-D_-lXC^r8@nijAThA>niq8|dE& zSc4-l`FR|yf}w6v0>?Pc6H@V@IpP^koLQTM?(R(xe&9(MWTV5S{^D)WZiGhX$Uk-+ z5+H;%)^}CpSD~!=>P&#n=j0OJR&VqS=)p&~QSm{gY|NA6PvD8k{5kueM|EEzzj?oU zs<49mGlN>@Pmn3l-y^;0Tvj6Y@jr<3QOL2LTSW;@4wsU*q-I(y_4blGJ=f~@0*12hWOUw$&urvTi3Kb(|uG7no17}XIS?; z6A}InNPpP=*^%TQAh*=KHNcNfSN>34jYb;;2p~%FzxE8+B{}k3IJ53S*=sxa(7S}; zA^eYb{x%q(WWG#Fo|OeihKM3bkN(~TlO7aWFII-@dzYxV#`~O}(biS$FxqHRLpd>& z`O%x3_D`-h*X%SY)|Z|jxJ$oO`+yqfa-}laS`P>_J#N#pai@`Qdi|)JFBAHHAMd7X zsXm&(%W}I;pM;^|1Uuy#g=*>W_cEn7UHj5yduhuJA>psY#4KSfqs^q*0PoQ@&TW6S zGA9Dtczw@pWWD7I&Dh?Xvsf4YX@pfWv@St zX!W&4>nm1=ED7!o`HcI&!Lmdxh@3nm6aGPkUT`JjTNIfa_$V(Ra&%nypSx!vuGiYu z=wXQ6hLe<6)9}A9!tti?mmU{KC|662x8fdS9PGncFk;SUCSO1e|hc-LM<8#4bM7!%;T7`NPaIF zv~TsE`@FwjEg(n86d&wJ+B9OI5N=2_VygAxuf(4vw& zQ5pm%j7~2O(X15BTC8E$!tCBP7F0mL>G zESh=z24eF+9rA7C!=wXhBSz^_j}iovFy*%iEzYfT1Z&&sJ#5y(zf!pp zdG$u_yC1G^v$H7wCr`7U64Nc3Hz<((0P;=(PxhX?!h`eNU zfK5K{Co-gTK7S@xD@j48Pwg#C}Hp(*$;xqGAK`nXgeDnYwcx^(_0HQ*(>Kjp2(gvmizG0Ld)xq?nN>~wl7_94p`_T*tn_GW1*jl~hj z^JHGfY=cS&mUl#IKFdc4rKP7s$dG~<$0=*bu1sD7oVEEC>u91tFab2=*SWFqpv&JR zBPnzkC8M1E2TodPhK&jbE-uqbsg0~>E!@#l`TSaAdX`b+CIYa{=2wz*u#wN6BQ5#a z2P~M{+s{yc41L7Tr3`|%5rFt(?j`P4A7~u!1k_crfY3q4#m$TB&QbYC*T*Uey|-_M zLsD;_6G5_cpB<81Q7{-I1FgKY6y_Nabba!4OsA7!E|2oBYs9Cpz$MOSUY?u6^{SXDEWF)h zwLq8FE$oN+jL?+&R#OM=p{U)E1U%F=b}_7~ZSp;Z9$lbCB8yEv7M=qVuWK(GTkil4 z)cpfMyVTDan-2C#6PMNrL&^U5u8eudek=m$5*d!9x7cl|avT1G6-D9?$2N=V4w{hL z9m1gk!JvI$dV&QjU81=cCZr-jC#R^mo}aw>vDV#NCdYhoi1kiR0SPb#@#gab^Q2z* z9_8yDjd6Hs{*?@#>mlQeI{UKrXEFWqs{B5VJqY~yLR`dQ#*uy+5Q!o{D-p&y6Jkx4 zpT04xY@ReLk3T zG8Gm8MjA;;Oo|Ti&S2@GkoW#K3-2+)$Eahnqd*&iE+vq$%1tHl*uVk0vO6BjBy@@Vu&1qR))= zHJD}Q!&k&rk~Fr0Z72u^i1mRau_+6NKAKud#RU1Urltvq)bdp3&eTRh(>6eiU7SBE zh6L?8xSlg`Pt^OJeZSRxOwOzj%05vh162S`h@ed~N!7RtkJ{mIx{BLB>#=42($qeRqaW{Bta8oZ9>!`H`=<*?giO;MRBx;}8vxaVMo`QL!XRgrfWz2wqnjd#24 z0t1^ZHAGkg2XpU|Yd*_M`u}w4(SX-#qaTM8x_AM??>wsa{;*(q`OOG1Y2zIbl$oc0 zlQ$9+_br?fR`mwowg1OkwPzB~X@}0GubeC1ePIEJToohbEt~gWD`sM)$PDTKEk#2g z_1{ZY_{TBnK=ONab9GW`I^pFo46z|}^zX);QpT))q2}AKqYH8Gf($(EPfv<)5r==* zdL28NQUA5vj+Kb}2LyzHHn2qn27J{9D3q9pUWQiLW?Bme!9!Z_B+=(6pcd3%7TTlm z?pLjUM4rzx6)uJ2M?nEdr!qdhI#+fv`DF*K$R}isTR+V6hHTN~a=_Km5kGT?ci(Wd zC3zMGgg0hDi==T^0kw@G{Z5Tl$~NH?(~cTj5hu;2m2YH`9ounSLd{!KJWx)4T8QyAGPNk6`jjMHFdmprd=zg%N4@)58 zxkE<}5DnG+AfJKgDECPNOeq|#jRrG)VGm&ba$@Uc-KY%2QsS7l;3gHLC%Wg852q;I${Hi?XDC7RRh*uDhc<(#|d9QspZgsOOXmTt_LsEAD~6}cpEJK zTjC*wAfd#;aGRk7g+zwW&S(N!c$b4hxH3OIdWFIuWlyH^1Uw~h_uq8V46I`)o?F0! zzXFTcJNPa6ASGTn5hpfO4)R_ZJ7;uGQ6j<*tV*+v?T!o893!Jrr~p&6JM-6l&Vm5d zJWMNkV_NAB?HM(pMlJb8n}s6jD0C)Bp?A;}YkvxpCE}uG`=9f;UHWe#lwbXmCik18 zOduKPz#R|C6Y_fOloXz)1Gmgb4ZG!Mf_JBcQ@%9O5mwO=wj6rp7)ozdeWndwL|8i~ zHw>(uk4FOvPKJFST3BGg0|r2&15h)K@KDN-iWstday?}j3x!}6iiH<7B7gu0hCK3RdqpLpV71h)EM88}_>xHPOUzuu~;`~D-;~?ff$Af@0PLV7yjZdyKo!U>N zv&v=rh%&`~G6}qvgfoR|pzv(oi^6R_tGwlXo}v#d!f+QeJn$?|Q2hss!p!BlSn=X1 zXvzVFqw{t!jsM9?I9PbrC7~nzR_TD}lrAK$T^JpkPC#7o;Sg{y>r5hF569yim3VhG z$`I=qd;`wUYw+m$%%UqDvo@b9PF6tY4Br2-qYtH0Mv@ibRhX;vj+VM!RiGer6eCL- z3VHSlu`L*2EgjESs#oc^QZc^_n~#J@{5`%GPpMfwVGqQp* zcgL6Yf$I7sIB#p-fW!2a0{@i z%dyzIxfd|Y==vCu_pWZC0LkO$RO-X^c@~Ag9J7iS2Iz|%mB$r*eq2aAJk34prf*sJ zqN>4rEt1c1ax!hXV~_p&^oK|M-GQeWyGY+qL_WZR3gBej@U_m+3djcm*s)c$I2o$a zv_y2twx6dCNS3NlZpNonBjZf{Ph1>{oQWWVuJ@rB^V2a<$K^m~5^)f+uZn`2Mb^-% z?do6f&BgZk8psaEX4oo0YI*M;JSn4^F*dFgDdGRA?Aw=IcK!Uhv;{elT%S^Dc^q2r z7!)_!jNTC+46`byZ;Q(-H_mSL?SaS0_J@9rfX8&8{@UpeK!6U1uI+Aj+U{$Ce3oo4 zqb?mf=?dJe62LgSP3S*`Gcv!Uyz5f99(GA;&%MJ(R{f(qF z5h@q|g`inRniI@|2bad@m1yoTIrm>-&x3C4o>7>+T37y#%D{VRJz!#-5Gi&k4Q|EOzQ2GGt>91`(YYkab}dJa1oFQkHXPFD_A}dN>J0hBjL3EOI%;eG1IH} zfp^dE!&oM<9*Zbu7d@*F;SamLG}p9Gp9kyB9A2Tkd=KC;-(`E)#6hCFV}sTF=R7)vEQ9RVR@j8Hpg0*MF!zOkjGaqne4YFdOdMRWvJr4j1?4*lMgy zES{@V>F?Nmu01J*9>dIzp~0e3*;(0CGiU0mxd2FR*>tRGu?6t>MPlFZj|}>8j38*Si!ehcE@X&!x5m>yQh2 zx{mvjWy3OKzZcuAoh6YAzbQUSg|~9U*ZC;IriEWdcg=yEbJh~%HshLTYZ;Gob0bQH z4i{&M9Y3^xO*IxAL*_PNHn*XGg2-it?m7Cthu;|AiL*>Ta{gnu%BM8l3+^n!JWo>r zmozaCwk%V{ndgpYcdbWdkvp-H;^9jznK?PJq+5OOpQ8SB%PB@rS+O^TyVZZ$*Qs*r zUUq4Y-=ndL-p!Tcx$E(4mZibl4kfKlN4~ugp5`zAESQ+(R8rrUbtkw^>(iJmCwqSU zFJsLi)fYpUODr60(_$8-f=7TGP4CigJKR+|{u_CcYC)TNo3bE9k*6$}Km2^^V0xyU zH`j&@$1pKBk$056b-7CPm%ZQ?IAYf4^C2_z z_ibwFu$3u!!(xbXiKe{<&NqTc+AQqu2n$-sYRpqEoS=;^BxROc;^xE^EXV+tmK2)! zyyuLJxth{okpt)Nu*dLBRGHVii^N3vhwatt-&a*6k6q#@;E3Jf9{_LQl0s*p5oGgk zA9k9Y7iVjFZ4Bsr04be!7bFbh-Q@2<;=MrFhZL=r4Y9ZmYUqL{E`N%-Nv;2leuARK z>C!e?tM}EJnuae8qMXf%FqI^(fB`rH4yOHGuzKjA8=3uW`-+;nWV;QVvG;X7Tbv<1 z=IWN9(;I#z=#>PIM$a`NOR9qLCC`bbeBA2an5YJ%&ssMacO+pdnT+aqc^sA<-FmEmtk-yrGez)M+rv2&q_VWa%bJJm8UJ3Uzk|62t3SA?BiU$kURe9iD zctR8549y=%9wvp+VLaY@8r#|)5vv~d3(n_n$s*wuLw$glqJR$U$V$np}-jL zkE9oYeH-o3cK(Kn=gIXai~bF2|3ZV9Cs{f zZuC|?1{NSRbdy}E&om;7)DqXL`!!j&y|vG5Wjjfdj8F%h?|v>Fp4GLzB;t0M1g8f#Y#ujUtC%_oM>x?tBz?{u*jZnT<-DbUo9MjRVQ zdvyzvgd!bJy<$cGt{b|2^cGDrD10ceH8N@Yaqhke_Y!^Z@{?hT9etL3i~#~zuv;$^ zl4qp>Fp-E`L^BIstkQ!QJr=?6QtPGKfD1=#_#)=#`oEje?H7qMqZ=#zqJaRn_HO0I zJJ(%ytt4S!2P|2O^Jq>hsQ>b8W!~kO%QIF)ZKa#g$=G@6LxP&`k_^xv<-%xp&#o4V z`7yb*$9zUSdJGrU2Ct%D$+A)Onk}L80H-J=Xl$?F}gfU zAHoJ-ww`+qH!_)q08RL%K!HG%O17`YR&*N0Z8h}vKx7KVi}17C3xEh>yHs?eE&V%6 zbA3-YLzWElEq>L}6qTVUm#MD)X8>1|L;A>MK=08-)oQ=eA}(}aN`G{ZY|>`nSh!bX zq}+K0LONe5{JYEl)m5q6_(l4tt|PvW9FRLX#vDi0T%pV2=`}Jk8he@cZi3P_CBj#m`$T3$S|UM8J)`G8?UXasP9#avIC1AzyhOzPY({1 zJ4CqAeX+J5mTV_tmC&+=OL82+0I5rI7wQ@+Gvec zXh5=&nYe@L^9Vq@F$o54w%)1=4F@1qj?!!PUCdoyvF2l|TE|S`sG9 zzVwPwZkw0Nwxwjo)F%wcI^{q)AhAs%I^}PEdH~dQo#8ikCKUILOoZ(cGr94v*5+C` z1MH&GD_(4;^2BV!o8+{yC*P0M1Cb~)$JvUNhcJ+mpQ#|C+EQb8@>$mn>ow!q%9k+< zY%gRrAk%xoNYZ#;iQ&R_oqT}`4AFP!W#HLM?v8-ezO-i4w42$Bo4Zt<{|uA48)V4m zgYRNfA9zA=-_7?*U-mIj@sWtjSc|cZ*IZI780`GjvNEZQmKa0mH9u^{mwYpF69k|e z?zS^KuQHEwD4e2xCyjIh3`j@5ekPL%rWgAG+4gGgWN`V=4x~7M@W0YTa;_pd!&k23 zX|CcgE{s-WR5y8(O8jrtwHi1Jf8*OTJS&b0e=pnAd968L?Wx9 z3KSR-1uV=!H<_$P5;0(n|NGkindET>{jZrNNe>dd%xSNLh3v~fz(YY+6;>r}_UZot D=4bd( diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index 5fde5ac7..c6746ccf 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -10,6 +10,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import ( rasterize_meshes, rasterize_meshes_python, ) +from pytorch3d.renderer.mesh.utils import ( + _clip_barycentric_coordinates, + _interpolate_zbuf, +) from pytorch3d.structures import Meshes from pytorch3d.utils import ico_sphere @@ -21,6 +25,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1) self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1) self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1) + self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1) self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1) def test_simple_cpu_naive(self): @@ -170,8 +175,29 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): verts2.requires_grad = True meshes_cuda = Meshes(verts=[verts2], faces=[faces2]) - args_cpu = (meshes_cpu, image_size, radius, faces_per_pixel) - args_cuda = (meshes_cuda, image_size, radius, faces_per_pixel, 0, 0) + barycentric_clip = True + args_cpu = ( + meshes_cpu, + image_size, + radius, + faces_per_pixel, + None, + None, + False, + barycentric_clip, + False, + ) + args_cuda = ( + meshes_cuda, + image_size, + radius, + faces_per_pixel, + 0, + 0, + False, + barycentric_clip, + False, + ) self._compare_impls( rasterize_meshes, rasterize_meshes, @@ -333,6 +359,39 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): idxs_cuda[:K] = sorted(idxs_cuda[:K]) self.assertEqual(idxs_cpu, idxs_cuda) + def test_python_vs_cpp_bary_clip(self): + torch.manual_seed(232) + N = 2 + V = 10 + F = 5 + verts1 = torch.randn(N, V, 3, requires_grad=True) + verts2 = verts1.detach().clone().requires_grad_(True) + faces = torch.randint(V, size=(N, F, 3)) + meshes1 = Meshes(verts1, faces) + meshes2 = Meshes(verts2, faces) + + kwargs = {"image_size": 24, "clip_barycentric_coords": True} + fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs) + fn2 = functools.partial(rasterize_meshes_python, meshes2, **kwargs) + args = () + self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) + + def test_cpp_vs_cuda_bary_clip(self): + meshes = ico_sphere(2, device=torch.device("cpu")) + verts1, faces1 = meshes.get_mesh_verts_faces(0) + verts1.requires_grad = True + meshes1 = Meshes(verts=[verts1], faces=[faces1]) + device = get_random_cuda_device() + verts2 = verts1.detach().to(device).requires_grad_(True) + faces2 = faces1.detach().clone().to(device) + meshes2 = Meshes(verts=[verts2], faces=[faces2]) + + kwargs = {"image_size": 64, "clip_barycentric_coords": True} + fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs) + fn2 = functools.partial(rasterize_meshes, meshes2, bin_size=0, **kwargs) + args = () + self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) + def test_python_vs_cpp_perspective_correct(self): torch.manual_seed(232) N = 2 @@ -621,6 +680,82 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): self.assertLess(zbuf_f_bary_diff, 1e-4) self.assertLess(zbuf_t_bary_diff, 1e-4) + def _test_barycentric_clipping(self, rasterize_meshes_fn, device, bin_size=None): + # fmt: off + verts = torch.tensor([ + [-0.4, -0.4, 10], # noqa: E241, E201 + [ 0.4, -0.4, 10], # noqa: E241, E201 + [ 0.0, 0.4, 20], # noqa: E241, E201 + ], dtype=torch.float32, device=device) + # fmt: on + faces = torch.tensor([[0, 1, 2]], device=device) + meshes = Meshes(verts=[verts], faces=[faces]) + kwargs = { + "meshes": meshes, + "image_size": 5, + "faces_per_pixel": 1, + "blur_radius": 0.2, + "perspective_correct": False, + "clip_barycentric_coords": False, # Initially set this to false + } + if bin_size != -1: + kwargs["bin_size"] = bin_size + + # Run with and without perspective correction + idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs) + + # fmt: off + expected_bary = torch.tensor([ + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-0.2500, -0.2500, 1.5000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ], + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-0.5000, 0.5000, 1.0000], # noqa: E241, E201 + [-0.0000, -0.0000, 1.0000], # noqa: E241, E201 + [ 0.5000, -0.5000, 1.0000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ], + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [-0.2500, 0.7500, 0.5000], # noqa: E241, E201 + [ 0.2500, 0.2500, 0.5000], # noqa: E241, E201 + [ 0.7500, -0.2500, 0.5000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ], + [ + [-0.5000, 1.5000, -0.0000], # noqa: E241, E201 + [-0.0000, 1.0000, -0.0000], # noqa: E241, E201 + [ 0.5000, 0.5000, -0.0000], # noqa: E241, E201 + [ 1.0000, -0.0000, -0.0000], # noqa: E241, E201 + [ 1.5000, -0.5000, 0.0000] # noqa: E241, E201 + ], + [ + [-1.0000, -1.0000, -1.0000], # noqa: E241, E201 + [ 0.2500, 1.2500, -0.5000], # noqa: E241, E201 + [ 0.7500, 0.7500, -0.5000], # noqa: E241, E201 + [ 1.2500, 0.2500, -0.5000], # noqa: E241, E201 + [-1.0000, -1.0000, -1.0000] # noqa: E241, E201 + ] + ], dtype=torch.float32, device=device).view(1, 5, 5, 1, 3) + # fmt: on + + self.assertClose(expected_bary, bary_f, atol=1e-4) + + # calculate the expected clipped barycentrics and zbuf + expected_bary_clipped = _clip_barycentric_coordinates(expected_bary) + expected_z_clipped = _interpolate_zbuf(idx_f, expected_bary_clipped, meshes) + + kwargs["clip_barycentric_coords"] = True + idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs) + + self.assertClose(expected_bary_clipped, bary_t, atol=1e-4) + self.assertClose(expected_z_clipped, zbuf_t, atol=1e-4) + def _test_behind_camera(self, rasterize_meshes_fn, device, bin_size=None): """ All verts are behind the camera so nothing should get rasterized. diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 0ae19471..c4325e30 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -212,6 +212,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): image_size=512, blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma, faces_per_pixel=80, + clip_barycentric_coords=True, ) # Init rasterizer settings @@ -269,11 +270,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): # the cow is facing the -z direction. lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None] + blend_params = BlendParams( + sigma=1e-1, + gamma=1e-4, + background_color=torch.tensor([1.0, 1.0, 1.0], device=device), + ) # Init renderer renderer = MeshRenderer( rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), shader=TexturedSoftPhongShader( - lights=lights, cameras=cameras, materials=materials + lights=lights, + cameras=cameras, + materials=materials, + blend_params=blend_params, ), ) @@ -346,6 +355,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): image_size=512, blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma, faces_per_pixel=100, + clip_barycentric_coords=True, ) # Load reference image