mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Replacing custom CUDA block reductions with CUB in sample_farthest_points
Summary: Removing hardcoded block reduction operation from `sample_farthest_points.cu` code, and replace it with `cub::BlockReduce` reducing complexity of the code, and letting established libraries do the thinking for us. Reviewed By: bottler Differential Revision: D38617147 fbshipit-source-id: b230029c55f05cda0aab1648d3105a8d3e92d27b
This commit is contained in:
		
							parent
							
								
									597bc7c7f6
								
							
						
					
					
						commit
						8ea4da2938
					
				@ -12,6 +12,7 @@
 | 
			
		||||
#include <math.h>
 | 
			
		||||
#include <stdio.h>
 | 
			
		||||
#include <stdlib.h>
 | 
			
		||||
#include <cub/cub.cuh>
 | 
			
		||||
#include "utils/warp_reduce.cuh"
 | 
			
		||||
 | 
			
		||||
template <unsigned int block_size>
 | 
			
		||||
@ -25,20 +26,19 @@ __global__ void FarthestPointSamplingKernel(
 | 
			
		||||
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
 | 
			
		||||
    // clang-format on
 | 
			
		||||
) {
 | 
			
		||||
  typedef cub::BlockReduce<
 | 
			
		||||
      cub::KeyValuePair<int64_t, float>,
 | 
			
		||||
      block_size,
 | 
			
		||||
      cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
 | 
			
		||||
      BlockReduce;
 | 
			
		||||
  __shared__ typename BlockReduce::TempStorage temp_storage;
 | 
			
		||||
  __shared__ int64_t selected_store;
 | 
			
		||||
 | 
			
		||||
  // Get constants
 | 
			
		||||
  const int64_t N = points.size(0);
 | 
			
		||||
  const int64_t P = points.size(1);
 | 
			
		||||
  const int64_t D = points.size(2);
 | 
			
		||||
 | 
			
		||||
  // Create single shared memory buffer which is split and cast to different
 | 
			
		||||
  // types: dists/dists_idx are used to save the maximum distances seen by the
 | 
			
		||||
  // points processed by any one thread and the associated point indices.
 | 
			
		||||
  // These values only need to be accessed by other threads in this block which
 | 
			
		||||
  // are processing the same batch and not by other blocks.
 | 
			
		||||
  extern __shared__ char shared_buf[];
 | 
			
		||||
  float* dists = (float*)shared_buf; // block_size floats
 | 
			
		||||
  int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t
 | 
			
		||||
 | 
			
		||||
  // Get batch index and thread index
 | 
			
		||||
  const int64_t batch_idx = blockIdx.x;
 | 
			
		||||
  const size_t tid = threadIdx.x;
 | 
			
		||||
@ -82,43 +82,26 @@ __global__ void FarthestPointSamplingKernel(
 | 
			
		||||
      max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // After going through all points for this thread, save the max
 | 
			
		||||
    // point and idx seen by this thread. Each thread sees P/block_size points.
 | 
			
		||||
    dists[tid] = max_dist;
 | 
			
		||||
    dists_idx[tid] = max_dist_idx;
 | 
			
		||||
    // Sync to ensure all threads in the block have updated their max point.
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    // Parallelized block reduction to find the max point seen by
 | 
			
		||||
    // all the threads in this block for iteration k.
 | 
			
		||||
    // Each block represents one batch element so we can use a divide/conquer
 | 
			
		||||
    // approach to find the max, syncing all threads after each step.
 | 
			
		||||
 | 
			
		||||
    for (int s = block_size / 2; s > 0; s >>= 1) {
 | 
			
		||||
      if (tid < s) {
 | 
			
		||||
        // Compare the best point seen by two threads and update the shared
 | 
			
		||||
        // memory at the location of the first thread index with the max out
 | 
			
		||||
        // of the two threads.
 | 
			
		||||
        if (dists[tid] < dists[tid + s]) {
 | 
			
		||||
          dists[tid] = dists[tid + s];
 | 
			
		||||
          dists_idx[tid] = dists_idx[tid + s];
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      __syncthreads();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // TODO(nikhilar): As reduction proceeds, the number of “active” threads
 | 
			
		||||
    // decreases. When tid < 32, there should only be one warp left which could
 | 
			
		||||
    // be unrolled.
 | 
			
		||||
 | 
			
		||||
    // The overall max after reducing will be saved
 | 
			
		||||
    // at the location of tid = 0.
 | 
			
		||||
    selected = dists_idx[0];
 | 
			
		||||
    // max_dist, max_dist_idx are now the max point and idx seen by this thread.
 | 
			
		||||
    // Now find the index corresponding to the maximum distance seen by any
 | 
			
		||||
    // thread. (This value is only on thread 0.)
 | 
			
		||||
    selected =
 | 
			
		||||
        BlockReduce(temp_storage)
 | 
			
		||||
            .Reduce(
 | 
			
		||||
                cub::KeyValuePair<int64_t, float>(max_dist_idx, max_dist),
 | 
			
		||||
                cub::ArgMax(),
 | 
			
		||||
                block_size)
 | 
			
		||||
            .key;
 | 
			
		||||
 | 
			
		||||
    if (tid == 0) {
 | 
			
		||||
      // Write the farthest point for iteration k to global memory
 | 
			
		||||
      idxs[batch_idx][k] = selected;
 | 
			
		||||
      selected_store = selected;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Ensure `selected` in all threads equals the global maximum.
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
    selected = selected_store;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -185,15 +168,8 @@ at::Tensor FarthestPointSamplingCuda(
 | 
			
		||||
  auto min_point_dist_a =
 | 
			
		||||
      min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();
 | 
			
		||||
 | 
			
		||||
  // Initialize the shared memory which will be used to store the
 | 
			
		||||
  // distance/index of the best point seen by each thread.
 | 
			
		||||
  size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t);
 | 
			
		||||
  // TODO: using shared memory for min_point_dist gives an ~2x speed up
 | 
			
		||||
  // compared to using a global (N, P) shaped tensor, however for
 | 
			
		||||
  // larger pointclouds this may exceed the shared memory limit per block.
 | 
			
		||||
  // If a speed up is required for smaller pointclouds, then the storage
 | 
			
		||||
  // could be switched to shared memory if the required total shared memory is
 | 
			
		||||
  // within the memory limit per block.
 | 
			
		||||
  // TempStorage for the reduction uses static shared memory only.
 | 
			
		||||
  size_t shared_mem = 0;
 | 
			
		||||
 | 
			
		||||
  // Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
 | 
			
		||||
  // block.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user