From 62dbf371aef2aeac11802901a771d85116a3717d Mon Sep 17 00:00:00 2001
From: Justin Johnson
Date: Wed, 8 Sep 2021 16:15:36 -0700
Subject: [PATCH] Move coarse rasterization to new file

Summary:
In preparation for sharing coarse rasterization between point clouds and meshes, move the functions to a new file. No code changes.

Reviewed By: bottler

Differential Revision: D30367812

fbshipit-source-id: 9e73835a26c4ac91f5c9f61ff682bc8218e36c6a
---
 .../bitmask.cuh | 0
 .../csrc/rasterize_coarse/rasterize_coarse.cu | 481 ++++++++++++++++++
 .../csrc/rasterize_coarse/rasterize_coarse.h | 38 ++
 .../csrc/rasterize_meshes/rasterize_meshes.cu | 233 ---------
 .../csrc/rasterize_meshes/rasterize_meshes.h | 15 +-
 .../csrc/rasterize_points/rasterize_points.cu | 226 --------
 .../csrc/rasterize_points/rasterize_points.h | 13 +-
 pytorch3d/csrc/utils/float_math.cuh | 8 +
 8 files changed, 534 insertions(+), 480 deletions(-)
 rename pytorch3d/csrc/{rasterize_points => rasterize_coarse}/bitmask.cuh (100%)
 create mode 100644 pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
 create mode 100644 pytorch3d/csrc/rasterize_coarse/rasterize_coarse.h

diff --git a/pytorch3d/csrc/rasterize_points/bitmask.cuh b/pytorch3d/csrc/rasterize_coarse/bitmask.cuh
similarity index 100%
rename from pytorch3d/csrc/rasterize_points/bitmask.cuh
rename to pytorch3d/csrc/rasterize_coarse/bitmask.cuh
diff --git a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
new file mode 100644
index 00000000..89250165
--- /dev/null
+++ b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
@@ -0,0 +1,481 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <float.h>
+#include <math.h>
+#include <thrust/tuple.h>
+#include "rasterize_coarse/bitmask.cuh"
+#include "rasterize_points/rasterization_utils.cuh"
+#include "utils/float_math.cuh"
+#include "utils/geometry_utils.cuh" // For kEpsilon -- gross
+
+// Get the xyz coordinates of the three vertices for the face given by the
+// index face_idx into face_verts.
+__device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
+ const float* face_verts,
+ int face_idx) {
+ const float x0 = face_verts[face_idx * 9 + 0];
+ const float y0 = face_verts[face_idx * 9 + 1];
+ const float z0 = face_verts[face_idx * 9 + 2];
+ const float x1 = face_verts[face_idx * 9 + 3];
+ const float y1 = face_verts[face_idx * 9 + 4];
+ const float z1 = face_verts[face_idx * 9 + 5];
+ const float x2 = face_verts[face_idx * 9 + 6];
+ const float y2 = face_verts[face_idx * 9 + 7];
+ const float z2 = face_verts[face_idx * 9 + 8];
+
+ const float3 v0xyz = make_float3(x0, y0, z0);
+ const float3 v1xyz = make_float3(x1, y1, z1);
+ const float3 v2xyz = make_float3(x2, y2, z2);
+
+ return thrust::make_tuple(v0xyz, v1xyz, v2xyz);
+}
+
+__global__ void RasterizeMeshesCoarseCudaKernel(
+ const float* face_verts,
+ const int64_t* mesh_to_face_first_idx,
+ const int64_t* num_faces_per_mesh,
+ const float blur_radius,
+ const int N,
+ const int F,
+ const int H,
+ const int W,
+ const int bin_size,
+ const int chunk_size,
+ const int max_faces_per_bin,
+ int* faces_per_bin,
+ int* bin_faces) {
+ extern __shared__ char sbuf[];
+ const int M = max_faces_per_bin;
+ // Integer divide round up
+ const int num_bins_x = 1 + (W - 1) / bin_size;
+ const int num_bins_y = 1 + (H - 1) / bin_size;
+
+ // NDC range depends on the ratio of W/H
+ // The shorter side from (H, W) is given an NDC range of 2.0 and
+ // the other side is scaled by the ratio of H:W.
+ const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
+ const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
+
+ // Size of half a pixel in NDC units is the NDC half range
+ // divided by the corresponding image dimension
+ const float half_pix_x = NDC_x_half_range / W;
+ const float half_pix_y = NDC_y_half_range / H;
+
+ // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
+ // stored in shared memory that will track whether each point in the chunk
+ // falls into each bin of the image.
+ BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
+
+ // Have each block handle a chunk of faces
+ const int chunks_per_batch = 1 + (F - 1) / chunk_size;
+ const int num_chunks = N * chunks_per_batch;
+
+ for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
+ const int batch_idx = chunk / chunks_per_batch; // batch index
+ const int chunk_idx = chunk % chunks_per_batch;
+ const int face_start_idx = chunk_idx * chunk_size;
+
+ binmask.block_clear();
+ const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx];
+ const int64_t mesh_face_stop_idx =
+ mesh_face_start_idx + num_faces_per_mesh[batch_idx];
+
+ // Have each thread handle a different face within the chunk
+ for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) {
+ const int f_idx = face_start_idx + f;
+
+ // Check if face index corresponds to the mesh in the batch given by
+ // batch_idx
+ if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) {
+ continue;
+ }
+
+ // Get xyz coordinates of the three face vertices.
+ const auto v012 = GetSingleFaceVerts(face_verts, f_idx);
+ const float3 v0 = thrust::get<0>(v012);
+ const float3 v1 = thrust::get<1>(v012);
+ const float3 v2 = thrust::get<2>(v012);
+
+ // Compute screen-space bbox for the triangle expanded by blur.
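+ // (blur_radius is compared against *squared* NDC distances in the fine
+ // rasterization kernels, so sqrt(blur_radius) below gives the Euclidean
+ // expansion; e.g. blur_radius = 1e-4 widens the bbox by 0.01 NDC on
+ // each side.)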
+ float xmin = FloatMin3(v0.x, v1.x, v2.x) - sqrt(blur_radius); + float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius); + float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius); + float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius); + float zmin = FloatMin3(v0.z, v1.z, v2.z); + + // Faces with at least one vertex behind the camera won't render + // correctly and should be removed or clipped before calling the + // rasterizer + if (zmin < kEpsilon) { + continue; + } + + // Brute-force search over all bins; TODO(T54294966) something smarter. + for (int by = 0; by < num_bins_y; ++by) { + // Y coordinate of the top and bottom of the bin. + // PixToNdc gives the location of the center of each pixel, so we + // need to add/subtract a half pixel to get the true extent of the bin. + // Reverse ordering of Y axis so that +Y is upwards in the image. + const float bin_y_min = + PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y; + const float bin_y_max = + PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y; + const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax); + + for (int bx = 0; bx < num_bins_x; ++bx) { + // X coordinate of the left and right of the bin. + // Reverse ordering of x axis so that +X is left. + const float bin_x_max = + PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x; + const float bin_x_min = + PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x; + + const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax); + if (y_overlap && x_overlap) { + binmask.set(by, bx, f); + } + } + } + } + __syncthreads(); + // Now we have processed every face in the current chunk. We need to + // count the number of faces in each bin so we can write the indices + // out to global memory. We have each thread handle a different bin. + for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x; + byx += blockDim.x) { + const int by = byx / num_bins_x; + const int bx = byx % num_bins_x; + const int count = binmask.count(by, bx); + const int faces_per_bin_idx = + batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx; + + // This atomically increments the (global) number of faces found + // in the current bin, and gets the previous value of the counter; + // this effectively allocates space in the bin_faces array for the + // faces in the current chunk that fall into this bin. + const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count); + + // Now loop over the binmask and write the active bits for this bin + // out to bin_faces. + int next_idx = batch_idx * num_bins_y * num_bins_x * M + + by * num_bins_x * M + bx * M + start; + for (int f = 0; f < chunk_size; ++f) { + if (binmask.get(by, bx, f)) { + // TODO(T54296346) find the correct method for handling errors in + // CUDA. Throw an error if num_faces_per_bin > max_faces_per_bin. 
+ // Either decrease bin size or increase max_faces_per_bin + bin_faces[next_idx] = face_start_idx + f; + next_idx++; + } + } + } + __syncthreads(); + } +} + +__global__ void RasterizePointsCoarseCudaKernel( + const float* points, // (P, 3) + const int64_t* cloud_to_packed_first_idx, // (N) + const int64_t* num_points_per_cloud, // (N) + const float* radius, + const int N, + const int P, + const int H, + const int W, + const int bin_size, + const int chunk_size, + const int max_points_per_bin, + int* points_per_bin, + int* bin_points) { + extern __shared__ char sbuf[]; + const int M = max_points_per_bin; + + // Integer divide round up + const int num_bins_x = 1 + (W - 1) / bin_size; + const int num_bins_y = 1 + (H - 1) / bin_size; + + // NDC range depends on the ratio of W/H + // The shorter side from (H, W) is given an NDC range of 2.0 and + // the other side is scaled by the ratio of H:W. + const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f; + const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f; + + // Size of half a pixel in NDC units is the NDC half range + // divided by the corresponding image dimension + const float half_pix_x = NDC_x_half_range / W; + const float half_pix_y = NDC_y_half_range / H; + + // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size) + // stored in shared memory that will track whether each point in the chunk + // falls into each bin of the image. + BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size); + + // Have each block handle a chunk of points and build a 3D bitmask in + // shared memory to mark which points hit which bins. In this first phase, + // each thread processes one point at a time. After processing the chunk, + // one thread is assigned per bin, and the thread counts and writes the + // points for the bin out to global memory. + const int chunks_per_batch = 1 + (P - 1) / chunk_size; + const int num_chunks = N * chunks_per_batch; + for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) { + const int batch_idx = chunk / chunks_per_batch; + const int chunk_idx = chunk % chunks_per_batch; + const int point_start_idx = chunk_idx * chunk_size; + + binmask.block_clear(); + + // Using the batch index of the thread get the start and stop + // indices for the points. + const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx]; + const int64_t cloud_point_stop_idx = + cloud_point_start_idx + num_points_per_cloud[batch_idx]; + + // Have each thread handle a different point within the chunk + for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) { + const int p_idx = point_start_idx + p; + + // Check if point index corresponds to the cloud in the batch given by + // batch_idx. + if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) { + continue; + } + + const float px = points[p_idx * 3 + 0]; + const float py = points[p_idx * 3 + 1]; + const float pz = points[p_idx * 3 + 2]; + const float p_radius = radius[p_idx]; + if (pz < 0) + continue; // Don't render points behind the camera. + const float px0 = px - p_radius; + const float px1 = px + p_radius; + const float py0 = py - p_radius; + const float py1 = py + p_radius; + + // Brute-force search over all bins; TODO something smarter? + // For example we could compute the exact bin where the point falls, + // then check neighboring bins. This way we wouldn't have to check + // all bins (however then we might have more warp divergence?) 
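+ // A sketch of that alternative (not used here; ndc_x_min and bin_ndc_w
+ // would be new locals for the left edge and the per-bin NDC width):
+ // map the point's bbox straight to a clamped bin range, e.g.
+ //   bx_lo = max(0, (int)((px0 - ndc_x_min) / bin_ndc_w));
+ //   bx_hi = min(num_bins_x - 1, (int)((px1 - ndc_x_min) / bin_ndc_w));
+ // and loop only over [bx_lo, bx_hi] (likewise for by), minding the
+ // reversed +X/+Y conventions used below.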
+ for (int by = 0; by < num_bins_y; ++by) { + // Get y extent for the bin. PixToNonSquareNdc gives us the location of + // the center of each pixel, so we need to add/subtract a half + // pixel to get the true extent of the bin. + const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y; + const float by1 = + PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y; + const bool y_overlap = (py0 <= by1) && (by0 <= py1); + + if (!y_overlap) { + continue; + } + for (int bx = 0; bx < num_bins_x; ++bx) { + // Get x extent for the bin; again we need to adjust the + // output of PixToNonSquareNdc by half a pixel. + const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x; + const float bx1 = + PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x; + const bool x_overlap = (px0 <= bx1) && (bx0 <= px1); + + if (x_overlap) { + binmask.set(by, bx, p); + } + } + } + } + __syncthreads(); + // Now we have processed every point in the current chunk. We need to + // count the number of points in each bin so we can write the indices + // out to global memory. We have each thread handle a different bin. + for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x; + byx += blockDim.x) { + const int by = byx / num_bins_x; + const int bx = byx % num_bins_x; + const int count = binmask.count(by, bx); + const int points_per_bin_idx = + batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx; + + // This atomically increments the (global) number of points found + // in the current bin, and gets the previous value of the counter; + // this effectively allocates space in the bin_points array for the + // points in the current chunk that fall into this bin. + const int start = atomicAdd(points_per_bin + points_per_bin_idx, count); + + // Now loop over the binmask and write the active bits for this bin + // out to bin_points. + int next_idx = batch_idx * num_bins_y * num_bins_x * M + + by * num_bins_x * M + bx * M + start; + for (int p = 0; p < chunk_size; ++p) { + if (binmask.get(by, bx, p)) { + // TODO: Throw an error if next_idx >= M -- this means that + // we got more than max_points_per_bin in this bin + // TODO: check if atomicAdd is needed in line 265. 
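+ // One possible guard (sketch only; bin_end would be a new local
+ // computed above as the bin's base offset plus M): skip the write
+ // when next_idx >= bin_end, so an overfull bin fails gracefully
+ // instead of spilling into the next bin's slots.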
+ bin_points[next_idx] = point_start_idx + p;
+ next_idx++;
+ }
+ }
+ }
+ __syncthreads();
+ }
+}
+
+at::Tensor RasterizeMeshesCoarseCuda(
+ const at::Tensor& face_verts,
+ const at::Tensor& mesh_to_face_first_idx,
+ const at::Tensor& num_faces_per_mesh,
+ const std::tuple<int, int> image_size,
+ const float blur_radius,
+ const int bin_size,
+ const int max_faces_per_bin) {
+ TORCH_CHECK(
+ face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
+ face_verts.size(2) == 3,
+ "face_verts must have dimensions (num_faces, 3, 3)");
+
+ // Check inputs are on the same device
+ at::TensorArg face_verts_t{face_verts, "face_verts", 1},
+ mesh_to_face_first_idx_t{
+ mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
+ num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
+ at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
+ at::checkAllSameGPU(
+ c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(face_verts.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const int H = std::get<0>(image_size);
+ const int W = std::get<1>(image_size);
+
+ const int F = face_verts.size(0);
+ const int N = num_faces_per_mesh.size(0);
+ const int M = max_faces_per_bin;
+
+ // Integer divide round up.
+ const int num_bins_y = 1 + (H - 1) / bin_size;
+ const int num_bins_x = 1 + (W - 1) / bin_size;
+
+ if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
+ std::stringstream ss;
+ ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
+ << ", num_bins_x: " << num_bins_x << ", "
+ << "; that's too many!";
+ AT_ERROR(ss.str());
+ }
+ auto opts = num_faces_per_mesh.options().dtype(at::kInt);
+ at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
+ at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
+
+ if (bin_faces.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return bin_faces;
+ }
+
+ const int chunk_size = 512;
+ const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
+ const size_t blocks = 64;
+ const size_t threads = 512;
+
+ RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
+ face_verts.contiguous().data_ptr<float>(),
+ mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
+ num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
+ blur_radius,
+ N,
+ F,
+ H,
+ W,
+ bin_size,
+ chunk_size,
+ M,
+ faces_per_bin.data_ptr<int>(),
+ bin_faces.data_ptr<int>());
+
+ AT_CUDA_CHECK(cudaGetLastError());
+ return bin_faces;
+}
+
+at::Tensor RasterizePointsCoarseCuda(
+ const at::Tensor& points, // (P, 3)
+ const at::Tensor& cloud_to_packed_first_idx, // (N)
+ const at::Tensor& num_points_per_cloud, // (N)
+ const std::tuple<int, int> image_size,
+ const at::Tensor& radius,
+ const int bin_size,
+ const int max_points_per_bin) {
+ TORCH_CHECK(
+ points.ndimension() == 2 && points.size(1) == 3,
+ "points must have dimensions (num_points, 3)");
+
+ // Check inputs are on the same device
+ at::TensorArg points_t{points, "points", 1},
+ cloud_to_packed_first_idx_t{
+ cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
+ num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
+ at::CheckedFrom c = "RasterizePointsCoarseCuda";
+ at::checkAllSameGPU(
+ c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(points.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const int H = std::get<0>(image_size);
+ const int W = std::get<1>(image_size);
+
+ const int P = points.size(0);
+ const int N = num_points_per_cloud.size(0);
+ const int M = max_points_per_bin;
+
+ // Integer divide round up.
+ const int num_bins_y = 1 + (H - 1) / bin_size;
+ const int num_bins_x = 1 + (W - 1) / bin_size;
+
+ if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
+ // Make sure we do not use too much shared memory.
+ std::stringstream ss;
+ ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
+ << ", num_bins_x: " << num_bins_x << ", "
+ << "; that's too many!";
+ AT_ERROR(ss.str());
+ }
+ auto opts = num_points_per_cloud.options().dtype(at::kInt);
+ at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
+ at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
+
+ if (bin_points.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return bin_points;
+ }
+
+ const int chunk_size = 512;
+ const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
+ const size_t blocks = 64;
+ const size_t threads = 512;
+
+ RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
+ points.contiguous().data_ptr<float>(),
+ cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
+ num_points_per_cloud.contiguous().data_ptr<int64_t>(),
+ radius.contiguous().data_ptr<float>(),
+ N,
+ P,
+ H,
+ W,
+ bin_size,
+ chunk_size,
+ M,
+ points_per_bin.contiguous().data_ptr<int>(),
+ bin_points.contiguous().data_ptr<int>());
+
+ AT_CUDA_CHECK(cudaGetLastError());
+ return bin_points;
+}
diff --git a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.h b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.h
new file mode 100644
index 00000000..6c0f0b41
--- /dev/null
+++ b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <torch/extension.h>
+#include <tuple>
+
+// Arguments are the same as RasterizeMeshesCoarse from
+// rasterize_meshes/rasterize_meshes.h
+#ifdef WITH_CUDA
+torch::Tensor RasterizeMeshesCoarseCuda(
+ const torch::Tensor& face_verts,
+ const torch::Tensor& mesh_to_face_first_idx,
+ const torch::Tensor& num_faces_per_mesh,
+ const std::tuple<int, int> image_size,
+ const float blur_radius,
+ const int bin_size,
+ const int max_faces_per_bin);
+#endif
+
+// Arguments are the same as RasterizePointsCoarse from
+// rasterize_points/rasterize_points.h
+#ifdef WITH_CUDA
+torch::Tensor RasterizePointsCoarseCuda(
+ const torch::Tensor& points,
+ const torch::Tensor& cloud_to_packed_first_idx,
+ const torch::Tensor& num_points_per_cloud,
+ const std::tuple<int, int> image_size,
+ const torch::Tensor& radius,
+ const int bin_size,
+ const int max_points_per_bin);
+#endif
diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
index 2dd5fde1..73e89974 100644
--- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
+++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
@@ -14,7 +14,6 @@
 #include <thrust/tuple.h>
 #include <cstdio>
 #include <tuple>
-#include "rasterize_points/bitmask.cuh"
 #include "rasterize_points/rasterization_utils.cuh"
 #include "utils/float_math.cuh"
 #include "utils/geometry_utils.cuh"
@@ -32,14 +31,6 @@ __device__ bool operator<(const Pixel& a, const Pixel& b) {
 return a.z < b.z;
 }
 
-__device__ float FloatMin3(const float p1, const float p2, const float p3) {
- return fminf(p1, fminf(p2, p3));
-}
-
-__device__ float FloatMax3(const float p1, const float p2, const float p3) {
- return fmaxf(p1, fmaxf(p2, p3));
-}
-
 // Get the xyz coordinates of the three vertices for the face given by the
 // index face_idx into face_verts.
 __device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
@@ -630,230 +621,6 @@ at::Tensor RasterizeMeshesBackwardCuda(
 return grad_face_verts;
 }
 
-// ****************************************************************************
-// * COARSE RASTERIZATION *
-// ****************************************************************************
-
-__global__ void RasterizeMeshesCoarseCudaKernel(
- const float* face_verts,
- const int64_t* mesh_to_face_first_idx,
- const int64_t* num_faces_per_mesh,
- const float blur_radius,
- const int N,
- const int F,
- const int H,
- const int W,
- const int bin_size,
- const int chunk_size,
- const int max_faces_per_bin,
- int* faces_per_bin,
- int* bin_faces) {
- extern __shared__ char sbuf[];
- const int M = max_faces_per_bin;
- // Integer divide round up
- const int num_bins_x = 1 + (W - 1) / bin_size;
- const int num_bins_y = 1 + (H - 1) / bin_size;
-
- // NDC range depends on the ratio of W/H
- // The shorter side from (H, W) is given an NDC range of 2.0 and
- // the other side is scaled by the ratio of H:W.
- const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
- const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
-
- // Size of half a pixel in NDC units is the NDC half range
- // divided by the corresponding image dimension
- const float half_pix_x = NDC_x_half_range / W;
- const float half_pix_y = NDC_y_half_range / H;
-
- // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
- // stored in shared memory that will track whether each point in the chunk
- // falls into each bin of the image.
- BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size); - - // Have each block handle a chunk of faces - const int chunks_per_batch = 1 + (F - 1) / chunk_size; - const int num_chunks = N * chunks_per_batch; - - for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) { - const int batch_idx = chunk / chunks_per_batch; // batch index - const int chunk_idx = chunk % chunks_per_batch; - const int face_start_idx = chunk_idx * chunk_size; - - binmask.block_clear(); - const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx]; - const int64_t mesh_face_stop_idx = - mesh_face_start_idx + num_faces_per_mesh[batch_idx]; - - // Have each thread handle a different face within the chunk - for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) { - const int f_idx = face_start_idx + f; - - // Check if face index corresponds to the mesh in the batch given by - // batch_idx - if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) { - continue; - } - - // Get xyz coordinates of the three face vertices. - const auto v012 = GetSingleFaceVerts(face_verts, f_idx); - const float3 v0 = thrust::get<0>(v012); - const float3 v1 = thrust::get<1>(v012); - const float3 v2 = thrust::get<2>(v012); - - // Compute screen-space bbox for the triangle expanded by blur. - float xmin = FloatMin3(v0.x, v1.x, v2.x) - sqrt(blur_radius); - float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius); - float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius); - float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius); - float zmin = FloatMin3(v0.z, v1.z, v2.z); - - // Faces with at least one vertex behind the camera won't render - // correctly and should be removed or clipped before calling the - // rasterizer - if (zmin < kEpsilon) { - continue; - } - - // Brute-force search over all bins; TODO(T54294966) something smarter. - for (int by = 0; by < num_bins_y; ++by) { - // Y coordinate of the top and bottom of the bin. - // PixToNdc gives the location of the center of each pixel, so we - // need to add/subtract a half pixel to get the true extent of the bin. - // Reverse ordering of Y axis so that +Y is upwards in the image. - const float bin_y_min = - PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y; - const float bin_y_max = - PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y; - const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax); - - for (int bx = 0; bx < num_bins_x; ++bx) { - // X coordinate of the left and right of the bin. - // Reverse ordering of x axis so that +X is left. - const float bin_x_max = - PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x; - const float bin_x_min = - PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x; - - const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax); - if (y_overlap && x_overlap) { - binmask.set(by, bx, f); - } - } - } - } - __syncthreads(); - // Now we have processed every face in the current chunk. We need to - // count the number of faces in each bin so we can write the indices - // out to global memory. We have each thread handle a different bin. 
- for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
- byx += blockDim.x) {
- const int by = byx / num_bins_x;
- const int bx = byx % num_bins_x;
- const int count = binmask.count(by, bx);
- const int faces_per_bin_idx =
- batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
-
- // This atomically increments the (global) number of faces found
- // in the current bin, and gets the previous value of the counter;
- // this effectively allocates space in the bin_faces array for the
- // faces in the current chunk that fall into this bin.
- const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count);
-
- // Now loop over the binmask and write the active bits for this bin
- // out to bin_faces.
- int next_idx = batch_idx * num_bins_y * num_bins_x * M +
- by * num_bins_x * M + bx * M + start;
- for (int f = 0; f < chunk_size; ++f) {
- if (binmask.get(by, bx, f)) {
- // TODO(T54296346) find the correct method for handling errors in
- // CUDA. Throw an error if num_faces_per_bin > max_faces_per_bin.
- // Either decrease bin size or increase max_faces_per_bin
- bin_faces[next_idx] = face_start_idx + f;
- next_idx++;
- }
- }
- }
- __syncthreads();
- }
-}
-
-at::Tensor RasterizeMeshesCoarseCuda(
- const at::Tensor& face_verts,
- const at::Tensor& mesh_to_face_first_idx,
- const at::Tensor& num_faces_per_mesh,
- const std::tuple<int, int> image_size,
- const float blur_radius,
- const int bin_size,
- const int max_faces_per_bin) {
- TORCH_CHECK(
- face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
- face_verts.size(2) == 3,
- "face_verts must have dimensions (num_faces, 3, 3)");
-
- // Check inputs are on the same device
- at::TensorArg face_verts_t{face_verts, "face_verts", 1},
- mesh_to_face_first_idx_t{
- mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
- num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
- at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
- at::checkAllSameGPU(
- c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});
-
- // Set the device for the kernel launch based on the device of the input
- at::cuda::CUDAGuard device_guard(face_verts.device());
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
- const int H = std::get<0>(image_size);
- const int W = std::get<1>(image_size);
-
- const int F = face_verts.size(0);
- const int N = num_faces_per_mesh.size(0);
- const int M = max_faces_per_bin;
-
- // Integer divide round up.
- const int num_bins_y = 1 + (H - 1) / bin_size;
- const int num_bins_x = 1 + (W - 1) / bin_size;
-
- if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
- std::stringstream ss;
- ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
- << ", num_bins_x: " << num_bins_x << ", "
- << "; that's too many!";
- AT_ERROR(ss.str());
- }
- auto opts = num_faces_per_mesh.options().dtype(at::kInt);
- at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
- at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
-
- if (bin_faces.numel() == 0) {
- AT_CUDA_CHECK(cudaGetLastError());
- return bin_faces;
- }
-
- const int chunk_size = 512;
- const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
- const size_t blocks = 64;
- const size_t threads = 512;
-
- RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
- face_verts.contiguous().data_ptr<float>(),
- mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
- num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
- blur_radius,
- N,
- F,
- H,
- W,
- bin_size,
- chunk_size,
- M,
- faces_per_bin.data_ptr<int>(),
- bin_faces.data_ptr<int>());
-
- AT_CUDA_CHECK(cudaGetLastError());
- return bin_faces;
-}
-
 // ****************************************************************************
 // * FINE RASTERIZATION *
 // ****************************************************************************
diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h
index d3ae63f9..3e4b40ab 100644
--- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h
+++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h
@@ -10,6 +10,7 @@
 #include <torch/extension.h>
 #include <cstdio>
 #include <tuple>
+#include "rasterize_coarse/rasterize_coarse.h"
 #include "utils/pytorch3d_cutils.h"
 
 // ****************************************************************************
@@ -236,6 +237,8 @@ torch::Tensor RasterizeMeshesBackward(
 // * COARSE RASTERIZATION *
 // ****************************************************************************
 
+// RasterizeMeshesCoarseCuda in rasterize_coarse/rasterize_coarse.h
+
 torch::Tensor RasterizeMeshesCoarseCpu(
 const torch::Tensor& face_verts,
 const at::Tensor& mesh_to_face_first_idx,
@@ -245,16 +248,6 @@ torch::Tensor RasterizeMeshesCoarseCpu(
 const int bin_size,
 const int max_faces_per_bin);
 
-#ifdef WITH_CUDA
-torch::Tensor RasterizeMeshesCoarseCuda(
- const torch::Tensor& face_verts,
- const torch::Tensor& mesh_to_face_first_idx,
- const torch::Tensor& num_faces_per_mesh,
- const std::tuple<int, int> image_size,
- const float blur_radius,
- const int bin_size,
- const int max_faces_per_bin);
-#endif
 // Args:
 //  face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
 //    faces in all the meshes in the batch. Concretely,
@@ -499,7 +492,7 @@ RasterizeMeshes(
 const bool cull_backfaces) {
 if (bin_size > 0 && max_faces_per_bin > 0) {
 // Use coarse-to-fine rasterization
- auto bin_faces = RasterizeMeshesCoarse(
+ at::Tensor bin_faces = RasterizeMeshesCoarse(
 face_verts,
 mesh_to_face_first_idx,
 num_faces_per_mesh,
diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.cu b/pytorch3d/csrc/rasterize_points/rasterize_points.cu
index fcfb769b..251be7f4 100644
--- a/pytorch3d/csrc/rasterize_points/rasterize_points.cu
+++ b/pytorch3d/csrc/rasterize_points/rasterize_points.cu
@@ -13,7 +13,6 @@
 #include <cstdio>
 #include <sstream>
 #include <tuple>
-#include "rasterize_points/bitmask.cuh"
 #include "rasterize_points/rasterization_utils.cuh"
 
 namespace {
@@ -217,231 +216,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
 return std::make_tuple(point_idxs, zbuf, pix_dists);
 }
 
-// ****************************************************************************
-// * COARSE RASTERIZATION *
-// ****************************************************************************
-
-__global__ void RasterizePointsCoarseCudaKernel(
- const float* points, // (P, 3)
- const int64_t* cloud_to_packed_first_idx, // (N)
- const int64_t* num_points_per_cloud, // (N)
- const float* radius,
- const int N,
- const int P,
- const int H,
- const int W,
- const int bin_size,
- const int chunk_size,
- const int max_points_per_bin,
- int* points_per_bin,
- int* bin_points) {
- extern __shared__ char sbuf[];
- const int M = max_points_per_bin;
-
- // Integer divide round up
- const int num_bins_x = 1 + (W - 1) / bin_size;
- const int num_bins_y = 1 + (H - 1) / bin_size;
-
- // NDC range depends on the ratio of W/H
- // The shorter side from (H, W) is given an NDC range of 2.0 and
- // the other side is scaled by the ratio of H:W.
- const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
- const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
-
- // Size of half a pixel in NDC units is the NDC half range
- // divided by the corresponding image dimension
- const float half_pix_x = NDC_x_half_range / W;
- const float half_pix_y = NDC_y_half_range / H;
-
- // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
- // stored in shared memory that will track whether each point in the chunk
- // falls into each bin of the image.
- BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
-
- // Have each block handle a chunk of points and build a 3D bitmask in
- // shared memory to mark which points hit which bins. In this first phase,
- // each thread processes one point at a time. After processing the chunk,
- // one thread is assigned per bin, and the thread counts and writes the
- // points for the bin out to global memory.
- const int chunks_per_batch = 1 + (P - 1) / chunk_size;
- const int num_chunks = N * chunks_per_batch;
- for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
- const int batch_idx = chunk / chunks_per_batch;
- const int chunk_idx = chunk % chunks_per_batch;
- const int point_start_idx = chunk_idx * chunk_size;
-
- binmask.block_clear();
-
- // Using the batch index of the thread get the start and stop
- // indices for the points.
- const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx]; - const int64_t cloud_point_stop_idx = - cloud_point_start_idx + num_points_per_cloud[batch_idx]; - - // Have each thread handle a different point within the chunk - for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) { - const int p_idx = point_start_idx + p; - - // Check if point index corresponds to the cloud in the batch given by - // batch_idx. - if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) { - continue; - } - - const float px = points[p_idx * 3 + 0]; - const float py = points[p_idx * 3 + 1]; - const float pz = points[p_idx * 3 + 2]; - const float p_radius = radius[p_idx]; - if (pz < 0) - continue; // Don't render points behind the camera. - const float px0 = px - p_radius; - const float px1 = px + p_radius; - const float py0 = py - p_radius; - const float py1 = py + p_radius; - - // Brute-force search over all bins; TODO something smarter? - // For example we could compute the exact bin where the point falls, - // then check neighboring bins. This way we wouldn't have to check - // all bins (however then we might have more warp divergence?) - for (int by = 0; by < num_bins_y; ++by) { - // Get y extent for the bin. PixToNonSquareNdc gives us the location of - // the center of each pixel, so we need to add/subtract a half - // pixel to get the true extent of the bin. - const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y; - const float by1 = - PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y; - const bool y_overlap = (py0 <= by1) && (by0 <= py1); - - if (!y_overlap) { - continue; - } - for (int bx = 0; bx < num_bins_x; ++bx) { - // Get x extent for the bin; again we need to adjust the - // output of PixToNonSquareNdc by half a pixel. - const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x; - const float bx1 = - PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x; - const bool x_overlap = (px0 <= bx1) && (bx0 <= px1); - - if (x_overlap) { - binmask.set(by, bx, p); - } - } - } - } - __syncthreads(); - // Now we have processed every point in the current chunk. We need to - // count the number of points in each bin so we can write the indices - // out to global memory. We have each thread handle a different bin. - for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x; - byx += blockDim.x) { - const int by = byx / num_bins_x; - const int bx = byx % num_bins_x; - const int count = binmask.count(by, bx); - const int points_per_bin_idx = - batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx; - - // This atomically increments the (global) number of points found - // in the current bin, and gets the previous value of the counter; - // this effectively allocates space in the bin_points array for the - // points in the current chunk that fall into this bin. - const int start = atomicAdd(points_per_bin + points_per_bin_idx, count); - - // Now loop over the binmask and write the active bits for this bin - // out to bin_points. - int next_idx = batch_idx * num_bins_y * num_bins_x * M + - by * num_bins_x * M + bx * M + start; - for (int p = 0; p < chunk_size; ++p) { - if (binmask.get(by, bx, p)) { - // TODO: Throw an error if next_idx >= M -- this means that - // we got more than max_points_per_bin in this bin - // TODO: check if atomicAdd is needed in line 265. 
- bin_points[next_idx] = point_start_idx + p;
- next_idx++;
- }
- }
- }
- __syncthreads();
- }
-}
-
-at::Tensor RasterizePointsCoarseCuda(
- const at::Tensor& points, // (P, 3)
- const at::Tensor& cloud_to_packed_first_idx, // (N)
- const at::Tensor& num_points_per_cloud, // (N)
- const std::tuple<int, int> image_size,
- const at::Tensor& radius,
- const int bin_size,
- const int max_points_per_bin) {
- TORCH_CHECK(
- points.ndimension() == 2 && points.size(1) == 3,
- "points must have dimensions (num_points, 3)");
-
- // Check inputs are on the same device
- at::TensorArg points_t{points, "points", 1},
- cloud_to_packed_first_idx_t{
- cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
- num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
- at::CheckedFrom c = "RasterizePointsCoarseCuda";
- at::checkAllSameGPU(
- c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
-
- // Set the device for the kernel launch based on the device of the input
- at::cuda::CUDAGuard device_guard(points.device());
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
- const int H = std::get<0>(image_size);
- const int W = std::get<1>(image_size);
-
- const int P = points.size(0);
- const int N = num_points_per_cloud.size(0);
- const int M = max_points_per_bin;
-
- // Integer divide round up.
- const int num_bins_y = 1 + (H - 1) / bin_size;
- const int num_bins_x = 1 + (W - 1) / bin_size;
-
- if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
- // Make sure we do not use too much shared memory.
- std::stringstream ss;
- ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
- << ", num_bins_x: " << num_bins_x << ", "
- << "; that's too many!";
- AT_ERROR(ss.str());
- }
- auto opts = num_points_per_cloud.options().dtype(at::kInt);
- at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
- at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
-
- if (bin_points.numel() == 0) {
- AT_CUDA_CHECK(cudaGetLastError());
- return bin_points;
- }
-
- const int chunk_size = 512;
- const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
- const size_t blocks = 64;
- const size_t threads = 512;
-
- RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
- points.contiguous().data_ptr<float>(),
- cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
- num_points_per_cloud.contiguous().data_ptr<int64_t>(),
- radius.contiguous().data_ptr<float>(),
- N,
- P,
- H,
- W,
- bin_size,
- chunk_size,
- M,
- points_per_bin.contiguous().data_ptr<int>(),
- bin_points.contiguous().data_ptr<int>());
-
- AT_CUDA_CHECK(cudaGetLastError());
- return bin_points;
-}
-
 // ****************************************************************************
 // * FINE RASTERIZATION *
 // ****************************************************************************
diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.h b/pytorch3d/csrc/rasterize_points/rasterize_points.h
index e119bcb2..25ee9dcd 100644
--- a/pytorch3d/csrc/rasterize_points/rasterize_points.h
+++ b/pytorch3d/csrc/rasterize_points/rasterize_points.h
@@ -10,6 +10,7 @@
 #include <torch/extension.h>
 #include <cstdio>
 #include <tuple>
+#include "rasterize_coarse/rasterize_coarse.h"
 #include "utils/pytorch3d_cutils.h"
 
 // ****************************************************************************
@@ -104,6 +105,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
 // * COARSE RASTERIZATION *
 // ****************************************************************************
 
+// RasterizePointsCoarseCuda in rasterize_coarse/rasterize_coarse.h
+
 torch::Tensor RasterizePointsCoarseCpu(
 const torch::Tensor& points,
 const torch::Tensor& cloud_to_packed_first_idx,
@@ -113,16 +116,6 @@ torch::Tensor RasterizePointsCoarseCpu(
 const int bin_size,
 const int max_points_per_bin);
 
-#ifdef WITH_CUDA
-torch::Tensor RasterizePointsCoarseCuda(
- const torch::Tensor& points,
- const torch::Tensor& cloud_to_packed_first_idx,
- const torch::Tensor& num_points_per_cloud,
- const std::tuple<int, int> image_size,
- const torch::Tensor& radius,
- const int bin_size,
- const int max_points_per_bin);
-#endif
 // Args:
 //  points: Tensor of shape (P, 3) giving (packed) positions for
 //    points in all N pointclouds in the batch where P is the total
diff --git a/pytorch3d/csrc/utils/float_math.cuh b/pytorch3d/csrc/utils/float_math.cuh
index b28518b8..1f7b2de1 100644
--- a/pytorch3d/csrc/utils/float_math.cuh
+++ b/pytorch3d/csrc/utils/float_math.cuh
@@ -42,6 +42,14 @@ __device__ inline float2 operator*(const float a, const float2& b) {
 return make_float2(a * b.x, a * b.y);
 }
 
+__device__ inline float FloatMin3(const float a, const float b, const float c) {
+ return fminf(a, fminf(b, c));
+}
+
+__device__ inline float FloatMax3(const float a, const float b, const float c) {
+ return fmaxf(a, fmaxf(b, c));
+}
+
 __device__ inline float dot(const float2& a, const float2& b) {
 return a.x * b.x + a.y * b.y;
 }
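Usage note: the moved entry points keep their original signatures, so
callers only need the new rasterize_coarse/rasterize_coarse.h header. A
minimal sketch of a call (assuming a WITH_CUDA build and packed tensors
already on the GPU; the sizes below are illustrative, not prescribed):

  #include "rasterize_coarse/rasterize_coarse.h"

  // face_verts: (F, 3, 3) float tensor of packed face vertices (xy in
  // NDC); mesh_to_face_first_idx, num_faces_per_mesh: (N,) int64.
  // Returns bin_faces of shape (N, num_bins_y, num_bins_x,
  // max_faces_per_bin), padded with -1 in unused slots, which the fine
  // rasterization kernels then consume.
  at::Tensor bin_faces = RasterizeMeshesCoarseCuda(
      face_verts,
      mesh_to_face_first_idx,
      num_faces_per_mesh,
      std::make_tuple(256, 256), // image_size = (H, W)
      1e-4f, // blur_radius, in squared NDC units
      /*bin_size=*/32, // 256 / 32 -> an 8 x 8 grid of bins
      /*max_faces_per_bin=*/100);

With chunk_size fixed at 512 in the host functions, such a launch uses
num_bins_y * num_bins_x * 512 / 8 = 4096 bytes of shared memory per
block for the bitmask; the kMaxItemsPerBin check above is what guards
this shared-memory budget.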