mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-03 02:35:58 +08:00
Move coarse rasterization to new file
Summary: In preparation for sharing coarse rasterization between point clouds and meshes, move the functions to a new file. No code changes. Reviewed By: bottler Differential Revision: D30367812 fbshipit-source-id: 9e73835a26c4ac91f5c9f61ff682bc8218e36c6a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f2c44e3540
commit
62dbf371ae
@@ -1,79 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#define BINMASK_H
|
||||
|
||||
// A BitMask represents a bool array of shape (H, W, N). We pack values into
|
||||
// the bits of unsigned ints; a single unsigned int has B = 32 bits, so to hold
|
||||
// all values we use H * W * (N / B) = H * W * D values. We want to store
|
||||
// BitMasks in shared memory, so we assume that the memory has already been
|
||||
// allocated for it elsewhere.
|
||||
class BitMask {
|
||||
public:
|
||||
__device__ BitMask(unsigned int* data, int H, int W, int N)
|
||||
: data(data), H(H), W(W), B(8 * sizeof(unsigned int)), D(N / B) {
|
||||
// TODO: check if the data is null.
|
||||
N = ceilf(N % 32); // take ceil incase N % 32 != 0
|
||||
block_clear(); // clear the data
|
||||
}
|
||||
|
||||
// Use all threads in the current block to clear all bits of this BitMask
|
||||
__device__ void block_clear() {
|
||||
for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
|
||||
data[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ int _get_elem_idx(int y, int x, int d) {
|
||||
return y * W * D + x * D + d / B;
|
||||
}
|
||||
|
||||
__device__ int _get_bit_idx(int d) {
|
||||
return d % B;
|
||||
}
|
||||
|
||||
// Turn on a single bit (y, x, d)
|
||||
__device__ void set(int y, int x, int d) {
|
||||
int elem_idx = _get_elem_idx(y, x, d);
|
||||
int bit_idx = _get_bit_idx(d);
|
||||
const unsigned int mask = 1U << bit_idx;
|
||||
atomicOr(data + elem_idx, mask);
|
||||
}
|
||||
|
||||
// Turn off a single bit (y, x, d)
|
||||
__device__ void unset(int y, int x, int d) {
|
||||
int elem_idx = _get_elem_idx(y, x, d);
|
||||
int bit_idx = _get_bit_idx(d);
|
||||
const unsigned int mask = ~(1U << bit_idx);
|
||||
atomicAnd(data + elem_idx, mask);
|
||||
}
|
||||
|
||||
// Check whether the bit (y, x, d) is on or off
|
||||
__device__ bool get(int y, int x, int d) {
|
||||
int elem_idx = _get_elem_idx(y, x, d);
|
||||
int bit_idx = _get_bit_idx(d);
|
||||
return (data[elem_idx] >> bit_idx) & 1U;
|
||||
}
|
||||
|
||||
// Compute the number of bits set in the row (y, x, :)
|
||||
__device__ int count(int y, int x) {
|
||||
int total = 0;
|
||||
for (int i = 0; i < D; ++i) {
|
||||
int elem_idx = y * W * D + x * D + i;
|
||||
unsigned int elem = data[elem_idx];
|
||||
total += __popc(elem);
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned int* data;
|
||||
int H, W, B, D;
|
||||
};
|
||||
@@ -13,7 +13,6 @@
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include "rasterize_points/bitmask.cuh"
|
||||
#include "rasterize_points/rasterization_utils.cuh"
|
||||
|
||||
namespace {
|
||||
@@ -217,231 +216,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
||||
return std::make_tuple(point_idxs, zbuf, pix_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * COARSE RASTERIZATION *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void RasterizePointsCoarseCudaKernel(
|
||||
const float* points, // (P, 3)
|
||||
const int64_t* cloud_to_packed_first_idx, // (N)
|
||||
const int64_t* num_points_per_cloud, // (N)
|
||||
const float* radius,
|
||||
const int N,
|
||||
const int P,
|
||||
const int H,
|
||||
const int W,
|
||||
const int bin_size,
|
||||
const int chunk_size,
|
||||
const int max_points_per_bin,
|
||||
int* points_per_bin,
|
||||
int* bin_points) {
|
||||
extern __shared__ char sbuf[];
|
||||
const int M = max_points_per_bin;
|
||||
|
||||
// Integer divide round up
|
||||
const int num_bins_x = 1 + (W - 1) / bin_size;
|
||||
const int num_bins_y = 1 + (H - 1) / bin_size;
|
||||
|
||||
// NDC range depends on the ratio of W/H
|
||||
// The shorter side from (H, W) is given an NDC range of 2.0 and
|
||||
// the other side is scaled by the ratio of H:W.
|
||||
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
|
||||
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
|
||||
|
||||
// Size of half a pixel in NDC units is the NDC half range
|
||||
// divided by the corresponding image dimension
|
||||
const float half_pix_x = NDC_x_half_range / W;
|
||||
const float half_pix_y = NDC_y_half_range / H;
|
||||
|
||||
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
|
||||
// stored in shared memory that will track whether each point in the chunk
|
||||
// falls into each bin of the image.
|
||||
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
|
||||
|
||||
// Have each block handle a chunk of points and build a 3D bitmask in
|
||||
// shared memory to mark which points hit which bins. In this first phase,
|
||||
// each thread processes one point at a time. After processing the chunk,
|
||||
// one thread is assigned per bin, and the thread counts and writes the
|
||||
// points for the bin out to global memory.
|
||||
const int chunks_per_batch = 1 + (P - 1) / chunk_size;
|
||||
const int num_chunks = N * chunks_per_batch;
|
||||
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
|
||||
const int batch_idx = chunk / chunks_per_batch;
|
||||
const int chunk_idx = chunk % chunks_per_batch;
|
||||
const int point_start_idx = chunk_idx * chunk_size;
|
||||
|
||||
binmask.block_clear();
|
||||
|
||||
// Using the batch index of the thread get the start and stop
|
||||
// indices for the points.
|
||||
const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx];
|
||||
const int64_t cloud_point_stop_idx =
|
||||
cloud_point_start_idx + num_points_per_cloud[batch_idx];
|
||||
|
||||
// Have each thread handle a different point within the chunk
|
||||
for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
|
||||
const int p_idx = point_start_idx + p;
|
||||
|
||||
// Check if point index corresponds to the cloud in the batch given by
|
||||
// batch_idx.
|
||||
if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float px = points[p_idx * 3 + 0];
|
||||
const float py = points[p_idx * 3 + 1];
|
||||
const float pz = points[p_idx * 3 + 2];
|
||||
const float p_radius = radius[p_idx];
|
||||
if (pz < 0)
|
||||
continue; // Don't render points behind the camera.
|
||||
const float px0 = px - p_radius;
|
||||
const float px1 = px + p_radius;
|
||||
const float py0 = py - p_radius;
|
||||
const float py1 = py + p_radius;
|
||||
|
||||
// Brute-force search over all bins; TODO something smarter?
|
||||
// For example we could compute the exact bin where the point falls,
|
||||
// then check neighboring bins. This way we wouldn't have to check
|
||||
// all bins (however then we might have more warp divergence?)
|
||||
for (int by = 0; by < num_bins_y; ++by) {
|
||||
// Get y extent for the bin. PixToNonSquareNdc gives us the location of
|
||||
// the center of each pixel, so we need to add/subtract a half
|
||||
// pixel to get the true extent of the bin.
|
||||
const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
|
||||
const float by1 =
|
||||
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
|
||||
const bool y_overlap = (py0 <= by1) && (by0 <= py1);
|
||||
|
||||
if (!y_overlap) {
|
||||
continue;
|
||||
}
|
||||
for (int bx = 0; bx < num_bins_x; ++bx) {
|
||||
// Get x extent for the bin; again we need to adjust the
|
||||
// output of PixToNonSquareNdc by half a pixel.
|
||||
const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
|
||||
const float bx1 =
|
||||
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
|
||||
const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
|
||||
|
||||
if (x_overlap) {
|
||||
binmask.set(by, bx, p);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
// Now we have processed every point in the current chunk. We need to
|
||||
// count the number of points in each bin so we can write the indices
|
||||
// out to global memory. We have each thread handle a different bin.
|
||||
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
|
||||
byx += blockDim.x) {
|
||||
const int by = byx / num_bins_x;
|
||||
const int bx = byx % num_bins_x;
|
||||
const int count = binmask.count(by, bx);
|
||||
const int points_per_bin_idx =
|
||||
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
|
||||
|
||||
// This atomically increments the (global) number of points found
|
||||
// in the current bin, and gets the previous value of the counter;
|
||||
// this effectively allocates space in the bin_points array for the
|
||||
// points in the current chunk that fall into this bin.
|
||||
const int start = atomicAdd(points_per_bin + points_per_bin_idx, count);
|
||||
|
||||
// Now loop over the binmask and write the active bits for this bin
|
||||
// out to bin_points.
|
||||
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
|
||||
by * num_bins_x * M + bx * M + start;
|
||||
for (int p = 0; p < chunk_size; ++p) {
|
||||
if (binmask.get(by, bx, p)) {
|
||||
// TODO: Throw an error if next_idx >= M -- this means that
|
||||
// we got more than max_points_per_bin in this bin
|
||||
// TODO: check if atomicAdd is needed in line 265.
|
||||
bin_points[next_idx] = point_start_idx + p;
|
||||
next_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor RasterizePointsCoarseCuda(
|
||||
const at::Tensor& points, // (P, 3)
|
||||
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const at::Tensor& num_points_per_cloud, // (N)
|
||||
const std::tuple<int, int> image_size,
|
||||
const at::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
TORCH_CHECK(
|
||||
points.ndimension() == 2 && points.size(1) == 3,
|
||||
"points must have dimensions (num_points, 3)");
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
cloud_to_packed_first_idx_t{
|
||||
cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
|
||||
num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
|
||||
at::CheckedFrom c = "RasterizePointsCoarseCuda";
|
||||
at::checkAllSameGPU(
|
||||
c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int H = std::get<0>(image_size);
|
||||
const int W = std::get<1>(image_size);
|
||||
|
||||
const int P = points.size(0);
|
||||
const int N = num_points_per_cloud.size(0);
|
||||
const int M = max_points_per_bin;
|
||||
|
||||
// Integer divide round up.
|
||||
const int num_bins_y = 1 + (H - 1) / bin_size;
|
||||
const int num_bins_x = 1 + (W - 1) / bin_size;
|
||||
|
||||
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
|
||||
// Make sure we do not use too much shared memory.
|
||||
std::stringstream ss;
|
||||
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
|
||||
<< ", num_bins_x: " << num_bins_x << ", "
|
||||
<< "; that's too many!";
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
auto opts = num_points_per_cloud.options().dtype(at::kInt);
|
||||
at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
|
||||
at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
|
||||
|
||||
if (bin_points.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return bin_points;
|
||||
}
|
||||
|
||||
const int chunk_size = 512;
|
||||
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
|
||||
const size_t blocks = 64;
|
||||
const size_t threads = 512;
|
||||
|
||||
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
|
||||
radius.contiguous().data_ptr<float>(),
|
||||
N,
|
||||
P,
|
||||
H,
|
||||
W,
|
||||
bin_size,
|
||||
chunk_size,
|
||||
M,
|
||||
points_per_bin.contiguous().data_ptr<int32_t>(),
|
||||
bin_points.contiguous().data_ptr<int32_t>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return bin_points;
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * FINE RASTERIZATION *
|
||||
// ****************************************************************************
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
#include "rasterize_coarse/rasterize_coarse.h"
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// ****************************************************************************
|
||||
@@ -104,6 +105,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
|
||||
// * COARSE RASTERIZATION *
|
||||
// ****************************************************************************
|
||||
|
||||
// RasterizePointsCoarseCuda in rasterize_coarse/rasterize_coarse.h
|
||||
|
||||
torch::Tensor RasterizePointsCoarseCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
@@ -113,16 +116,6 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
const int bin_size,
|
||||
const int max_points_per_bin);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
torch::Tensor RasterizePointsCoarseCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const std::tuple<int, int> image_size,
|
||||
const torch::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin);
|
||||
#endif
|
||||
// Args:
|
||||
// points: Tensor of shape (P, 3) giving (packed) positions for
|
||||
// points in all N pointclouds in the batch where P is the total
|
||||
|
||||
Reference in New Issue
Block a user