diff --git a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
index 0f49e279..5cc76577 100644
--- a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
+++ b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
@@ -50,6 +50,29 @@ __global__ void TriangleBoundingBoxKernel(
   }
 }
 
+__global__ void PointBoundingBoxKernel(
+    const float* points, // (P, 3)
+    const float* radius, // (P,)
+    const int P,
+    float* bboxes, // (4, P)
+    bool* skip_points) {
+  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int num_threads = blockDim.x * gridDim.x;
+  for (int p = tid; p < P; p += num_threads) {
+    const float x = points[p * 3 + 0];
+    const float y = points[p * 3 + 1];
+    const float z = points[p * 3 + 2];
+    const float r = radius[p];
+    // TODO: change to kEpsilon to match triangles?
+    const bool skip = z < 0;
+    bboxes[0 * P + p] = x - r;
+    bboxes[1 * P + p] = x + r;
+    bboxes[2 * P + p] = y - r;
+    bboxes[3 * P + p] = y + r;
+    skip_points[p] = skip;
+  }
+}
+
 __global__ void RasterizeCoarseCudaKernel(
     const float* bboxes, // (4, E) (xmin, xmax, ymin, ymax)
     const bool* should_skip, // (E,)
@@ -242,150 +265,6 @@ at::Tensor RasterizeCoarseCuda(
   return bin_elems;
 }
 
-__global__ void RasterizePointsCoarseCudaKernel(
-    const float* points, // (P, 3)
-    const int64_t* cloud_to_packed_first_idx, // (N)
-    const int64_t* num_points_per_cloud, // (N)
-    const float* radius,
-    const int N,
-    const int P,
-    const int H,
-    const int W,
-    const int bin_size,
-    const int chunk_size,
-    const int max_points_per_bin,
-    int* points_per_bin,
-    int* bin_points) {
-  extern __shared__ char sbuf[];
-  const int M = max_points_per_bin;
-
-  // Integer divide round up
-  const int num_bins_x = 1 + (W - 1) / bin_size;
-  const int num_bins_y = 1 + (H - 1) / bin_size;
-
-  // NDC range depends on the ratio of W/H
-  // The shorter side from (H, W) is given an NDC range of 2.0 and
-  // the other side is scaled by the ratio of H:W.
-  const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
-  const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
-
-  // Size of half a pixel in NDC units is the NDC half range
-  // divided by the corresponding image dimension
-  const float half_pix_x = NDC_x_half_range / W;
-  const float half_pix_y = NDC_y_half_range / H;
-
-  // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
-  // stored in shared memory that will track whether each point in the chunk
-  // falls into each bin of the image.
-  BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
-
-  // Have each block handle a chunk of points and build a 3D bitmask in
-  // shared memory to mark which points hit which bins. In this first phase,
-  // each thread processes one point at a time. After processing the chunk,
-  // one thread is assigned per bin, and the thread counts and writes the
-  // points for the bin out to global memory.
-  const int chunks_per_batch = 1 + (P - 1) / chunk_size;
-  const int num_chunks = N * chunks_per_batch;
-  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
-    const int batch_idx = chunk / chunks_per_batch;
-    const int chunk_idx = chunk % chunks_per_batch;
-    const int point_start_idx = chunk_idx * chunk_size;
-
-    binmask.block_clear();
-
-    // Using the batch index of the thread get the start and stop
-    // indices for the points.
-    const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx];
-    const int64_t cloud_point_stop_idx =
-        cloud_point_start_idx + num_points_per_cloud[batch_idx];
-
-    // Have each thread handle a different point within the chunk
-    for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
-      const int p_idx = point_start_idx + p;
-
-      // Check if point index corresponds to the cloud in the batch given by
-      // batch_idx.
-      if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) {
-        continue;
-      }
-
-      const float px = points[p_idx * 3 + 0];
-      const float py = points[p_idx * 3 + 1];
-      const float pz = points[p_idx * 3 + 2];
-      const float p_radius = radius[p_idx];
-      if (pz < 0)
-        continue; // Don't render points behind the camera.
-      const float px0 = px - p_radius;
-      const float px1 = px + p_radius;
-      const float py0 = py - p_radius;
-      const float py1 = py + p_radius;
-
-      // Brute-force search over all bins; TODO something smarter?
-      // For example we could compute the exact bin where the point falls,
-      // then check neighboring bins. This way we wouldn't have to check
-      // all bins (however then we might have more warp divergence?)
-      for (int by = 0; by < num_bins_y; ++by) {
-        // Get y extent for the bin. PixToNonSquareNdc gives us the location of
-        // the center of each pixel, so we need to add/subtract a half
-        // pixel to get the true extent of the bin.
-        const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
-        const float by1 =
-            PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
-        const bool y_overlap = (py0 <= by1) && (by0 <= py1);
-
-        if (!y_overlap) {
-          continue;
-        }
-        for (int bx = 0; bx < num_bins_x; ++bx) {
-          // Get x extent for the bin; again we need to adjust the
-          // output of PixToNonSquareNdc by half a pixel.
-          const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
-          const float bx1 =
-              PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
-          const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
-
-          if (x_overlap) {
-            binmask.set(by, bx, p);
-          }
-        }
-      }
-    }
-    __syncthreads();
-    // Now we have processed every point in the current chunk. We need to
-    // count the number of points in each bin so we can write the indices
-    // out to global memory. We have each thread handle a different bin.
-    for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
-         byx += blockDim.x) {
-      const int by = byx / num_bins_x;
-      const int bx = byx % num_bins_x;
-      const int count = binmask.count(by, bx);
-      const int points_per_bin_idx =
-          batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
-
-      // This atomically increments the (global) number of points found
-      // in the current bin, and gets the previous value of the counter;
-      // this effectively allocates space in the bin_points array for the
-      // points in the current chunk that fall into this bin.
-      const int start = atomicAdd(points_per_bin + points_per_bin_idx, count);
-
-      // Now loop over the binmask and write the active bits for this bin
-      // out to bin_points.
-      int next_idx = batch_idx * num_bins_y * num_bins_x * M +
-          by * num_bins_x * M + bx * M + start;
-      for (int p = 0; p < chunk_size; ++p) {
-        if (binmask.get(by, bx, p)) {
-          // TODO: Throw an error if next_idx >= M -- this means that
-          // we got more than max_points_per_bin in this bin
-          // TODO: check if atomicAdd is needed in line 265.
-          bin_points[next_idx] = point_start_idx + p;
-          next_idx++;
-        }
-      }
-    }
-    __syncthreads();
-  }
-}
-
 at::Tensor RasterizeMeshesCoarseCuda(
     const at::Tensor& face_verts,
     const at::Tensor& mesh_to_face_first_idx,
@@ -442,8 +321,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
 
 at::Tensor RasterizePointsCoarseCuda(
     const at::Tensor& points, // (P, 3)
-    const at::Tensor& cloud_to_packed_first_idx, // (N)
-    const at::Tensor& num_points_per_cloud, // (N)
+    const at::Tensor& cloud_to_packed_first_idx, // (N,)
+    const at::Tensor& num_points_per_cloud, // (N,)
     const std::tuple<int, int> image_size,
     const at::Tensor& radius,
     const int bin_size,
@@ -465,54 +344,30 @@ at::Tensor RasterizePointsCoarseCuda(
   at::cuda::CUDAGuard device_guard(points.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  const int H = std::get<0>(image_size);
-  const int W = std::get<1>(image_size);
-
+  // Allocate tensors for bboxes and should_skip
   const int P = points.size(0);
-  const int N = num_points_per_cloud.size(0);
-  const int M = max_points_per_bin;
+  auto float_opts = points.options().dtype(at::kFloat);
+  auto bool_opts = points.options().dtype(at::kBool);
+  at::Tensor bboxes = at::empty({4, P}, float_opts);
+  at::Tensor should_skip = at::empty({P}, bool_opts);
 
-  // Integer divide round up.
-  const int num_bins_y = 1 + (H - 1) / bin_size;
-  const int num_bins_x = 1 + (W - 1) / bin_size;
-
-  if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
-    // Make sure we do not use too much shared memory.
-    std::stringstream ss;
-    ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
-       << ", num_bins_x: " << num_bins_x << ", "
-       << "; that's too many!";
-    AT_ERROR(ss.str());
-  }
-  auto opts = num_points_per_cloud.options().dtype(at::kInt);
-  at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
-  at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
-
-  if (bin_points.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
-    return bin_points;
-  }
-
-  const int chunk_size = 512;
-  const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
-  const size_t blocks = 64;
-  const size_t threads = 512;
-
-  RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
+  // Launch kernel to compute point bboxes
+  const size_t blocks = 128;
+  const size_t threads = 256;
+  PointBoundingBoxKernel<<<blocks, threads, 0, stream>>>(
      points.contiguous().data_ptr<float>(),
-      cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
-      num_points_per_cloud.contiguous().data_ptr<int64_t>(),
      radius.contiguous().data_ptr<float>(),
-      N,
      P,
-      H,
-      W,
-      bin_size,
-      chunk_size,
-      M,
-      points_per_bin.contiguous().data_ptr<int32_t>(),
-      bin_points.contiguous().data_ptr<int32_t>());
-
+      bboxes.contiguous().data_ptr<float>(),
+      should_skip.contiguous().data_ptr<bool>());
   AT_CUDA_CHECK(cudaGetLastError());
-  return bin_points;
+
+  return RasterizeCoarseCuda(
+      bboxes,
+      should_skip,
+      cloud_to_packed_first_idx,
+      num_points_per_cloud,
+      image_size,
+      bin_size,
+      max_points_per_bin);
 }
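
Note on the new kernel: PointBoundingBoxKernel uses a grid-stride loop, so the fixed (128, 256) launch configuration covers all P points no matter how large P is. Below is a minimal standalone sketch of that pattern, assuming a hypothetical kernel (GridStrideDemo) and payload that are not part of this diff:

    #include <cstdio>
    #include <cuda_runtime.h>

    // Each thread starts at its global index and strides by the total
    // number of launched threads, so the loop covers all P items
    // regardless of the grid size.
    __global__ void GridStrideDemo(const float* in, float* out, int P) {
      const int tid = blockIdx.x * blockDim.x + threadIdx.x;
      const int num_threads = blockDim.x * gridDim.x;
      for (int p = tid; p < P; p += num_threads) {
        out[p] = 2.0f * in[p]; // placeholder per-element work
      }
    }

    int main() {
      const int P = 1000;
      float *in, *out;
      cudaMallocManaged(&in, P * sizeof(float));
      cudaMallocManaged(&out, P * sizeof(float));
      for (int i = 0; i < P; ++i) in[i] = float(i);
      // 128 blocks x 256 threads mirrors the launch in this diff;
      // correctness does not depend on these numbers.
      GridStrideDemo<<<128, 256>>>(in, out, P);
      cudaDeviceSynchronize();
      printf("out[999] = %f\n", out[999]);
      cudaFree(in);
      cudaFree(out);
      return 0;
    }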