mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Replacing custom CUDA block reductions with CUB in sample_farthest_points
Summary: Removing hardcoded block reduction operation from `sample_farthest_points.cu` code, and replace it with `cub::BlockReduce` reducing complexity of the code, and letting established libraries do the thinking for us. Reviewed By: bottler Differential Revision: D38617147 fbshipit-source-id: b230029c55f05cda0aab1648d3105a8d3e92d27b
This commit is contained in:
parent
597bc7c7f6
commit
8ea4da2938
@ -12,6 +12,7 @@
|
|||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
#include "utils/warp_reduce.cuh"
|
#include "utils/warp_reduce.cuh"
|
||||||
|
|
||||||
template <unsigned int block_size>
|
template <unsigned int block_size>
|
||||||
@ -25,20 +26,19 @@ __global__ void FarthestPointSamplingKernel(
|
|||||||
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
|
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
|
||||||
// clang-format on
|
// clang-format on
|
||||||
) {
|
) {
|
||||||
|
typedef cub::BlockReduce<
|
||||||
|
cub::KeyValuePair<int64_t, float>,
|
||||||
|
block_size,
|
||||||
|
cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
|
||||||
|
BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
__shared__ int64_t selected_store;
|
||||||
|
|
||||||
// Get constants
|
// Get constants
|
||||||
const int64_t N = points.size(0);
|
const int64_t N = points.size(0);
|
||||||
const int64_t P = points.size(1);
|
const int64_t P = points.size(1);
|
||||||
const int64_t D = points.size(2);
|
const int64_t D = points.size(2);
|
||||||
|
|
||||||
// Create single shared memory buffer which is split and cast to different
|
|
||||||
// types: dists/dists_idx are used to save the maximum distances seen by the
|
|
||||||
// points processed by any one thread and the associated point indices.
|
|
||||||
// These values only need to be accessed by other threads in this block which
|
|
||||||
// are processing the same batch and not by other blocks.
|
|
||||||
extern __shared__ char shared_buf[];
|
|
||||||
float* dists = (float*)shared_buf; // block_size floats
|
|
||||||
int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t
|
|
||||||
|
|
||||||
// Get batch index and thread index
|
// Get batch index and thread index
|
||||||
const int64_t batch_idx = blockIdx.x;
|
const int64_t batch_idx = blockIdx.x;
|
||||||
const size_t tid = threadIdx.x;
|
const size_t tid = threadIdx.x;
|
||||||
@ -82,43 +82,26 @@ __global__ void FarthestPointSamplingKernel(
|
|||||||
max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
|
max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
|
||||||
}
|
}
|
||||||
|
|
||||||
// After going through all points for this thread, save the max
|
// max_dist, max_dist_idx are now the max point and idx seen by this thread.
|
||||||
// point and idx seen by this thread. Each thread sees P/block_size points.
|
// Now find the index corresponding to the maximum distance seen by any
|
||||||
dists[tid] = max_dist;
|
// thread. (This value is only on thread 0.)
|
||||||
dists_idx[tid] = max_dist_idx;
|
selected =
|
||||||
// Sync to ensure all threads in the block have updated their max point.
|
BlockReduce(temp_storage)
|
||||||
__syncthreads();
|
.Reduce(
|
||||||
|
cub::KeyValuePair<int64_t, float>(max_dist_idx, max_dist),
|
||||||
// Parallelized block reduction to find the max point seen by
|
cub::ArgMax(),
|
||||||
// all the threads in this block for iteration k.
|
block_size)
|
||||||
// Each block represents one batch element so we can use a divide/conquer
|
.key;
|
||||||
// approach to find the max, syncing all threads after each step.
|
|
||||||
|
|
||||||
for (int s = block_size / 2; s > 0; s >>= 1) {
|
|
||||||
if (tid < s) {
|
|
||||||
// Compare the best point seen by two threads and update the shared
|
|
||||||
// memory at the location of the first thread index with the max out
|
|
||||||
// of the two threads.
|
|
||||||
if (dists[tid] < dists[tid + s]) {
|
|
||||||
dists[tid] = dists[tid + s];
|
|
||||||
dists_idx[tid] = dists_idx[tid + s];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(nikhilar): As reduction proceeds, the number of “active” threads
|
|
||||||
// decreases. When tid < 32, there should only be one warp left which could
|
|
||||||
// be unrolled.
|
|
||||||
|
|
||||||
// The overall max after reducing will be saved
|
|
||||||
// at the location of tid = 0.
|
|
||||||
selected = dists_idx[0];
|
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
// Write the farthest point for iteration k to global memory
|
// Write the farthest point for iteration k to global memory
|
||||||
idxs[batch_idx][k] = selected;
|
idxs[batch_idx][k] = selected;
|
||||||
|
selected_store = selected;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure `selected` in all threads equals the global maximum.
|
||||||
|
__syncthreads();
|
||||||
|
selected = selected_store;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,15 +168,8 @@ at::Tensor FarthestPointSamplingCuda(
|
|||||||
auto min_point_dist_a =
|
auto min_point_dist_a =
|
||||||
min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();
|
min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();
|
||||||
|
|
||||||
// Initialize the shared memory which will be used to store the
|
// TempStorage for the reduction uses static shared memory only.
|
||||||
// distance/index of the best point seen by each thread.
|
size_t shared_mem = 0;
|
||||||
size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t);
|
|
||||||
// TODO: using shared memory for min_point_dist gives an ~2x speed up
|
|
||||||
// compared to using a global (N, P) shaped tensor, however for
|
|
||||||
// larger pointclouds this may exceed the shared memory limit per block.
|
|
||||||
// If a speed up is required for smaller pointclouds, then the storage
|
|
||||||
// could be switched to shared memory if the required total shared memory is
|
|
||||||
// within the memory limit per block.
|
|
||||||
|
|
||||||
// Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
|
// Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
|
||||||
// block.
|
// block.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user