// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

// This converts dynamic array lookups into static array lookups, for small
// arrays up to size 32.
//
// Suppose we have a small thread-local array:
//
// float vals[10];
//
// Ideally we should only index this array using static indices:
//
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
//
// If we do so, then the CUDA compiler may be able to place the array into
// registers, which can give a big performance improvement. However, if we
// access the array dynamically, the compiler may force the array into
// local memory, which has the same latency as global memory.
//
// These functions convert dynamic array access into static array access
// using a brute-force lookup table. They can be used like this:
//
// float vals[10];
// int idx = 3;
// float val = 3.14f;
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
//
// The implementation is based on fbcuda/RegisterUtils.cuh:
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
// To avoid depending on the entire library, we just reimplement these two
// functions. The fbcuda implementation is a bit more sophisticated, and uses
// the preprocessor to generate switch statements that go up to N for each
// value of N. We are lazy and just have a giant explicit switch statement.
//
// We might be able to use a template metaprogramming approach similar to
// DispatchKernel1D for this. However, DispatchKernel1D is intended to be used
// for dispatching to the correct CUDA kernel on the host, while this is
// intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on the
// device. However, I didn't actually benchmark or test this.
template <typename T, int N>
struct RegisterIndexUtils {
  __device__ __forceinline__ static T get(const T arr[N], int idx) {
    if (idx < 0 || idx >= N) return T();
    switch (idx) {
      case 0: return arr[0];
      case 1: return arr[1];
      case 2: return arr[2];
      case 3: return arr[3];
      case 4: return arr[4];
      case 5: return arr[5];
      case 6: return arr[6];
      case 7: return arr[7];
      case 8: return arr[8];
      case 9: return arr[9];
      case 10: return arr[10];
      case 11: return arr[11];
      case 12: return arr[12];
      case 13: return arr[13];
      case 14: return arr[14];
      case 15: return arr[15];
      case 16: return arr[16];
      case 17: return arr[17];
      case 18: return arr[18];
      case 19: return arr[19];
      case 20: return arr[20];
      case 21: return arr[21];
      case 22: return arr[22];
      case 23: return arr[23];
      case 24: return arr[24];
      case 25: return arr[25];
      case 26: return arr[26];
      case 27: return arr[27];
      case 28: return arr[28];
      case 29: return arr[29];
      case 30: return arr[30];
      case 31: return arr[31];
    }
    return T();
  }

  __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
    if (idx < 0 || idx >= N) return;
    switch (idx) {
      case 0: arr[0] = val; break;
      case 1: arr[1] = val; break;
      case 2: arr[2] = val; break;
      case 3: arr[3] = val; break;
      case 4: arr[4] = val; break;
      case 5: arr[5] = val; break;
      case 6: arr[6] = val; break;
      case 7: arr[7] = val; break;
      case 8: arr[8] = val; break;
      case 9: arr[9] = val; break;
      case 10: arr[10] = val; break;
      case 11: arr[11] = val; break;
      case 12: arr[12] = val; break;
      case 13: arr[13] = val; break;
      case 14: arr[14] = val; break;
      case 15: arr[15] = val; break;
      case 16: arr[16] = val; break;
      case 17: arr[17] = val; break;
      case 18: arr[18] = val; break;
      case 19: arr[19] = val; break;
      case 20: arr[20] = val; break;
      case 21: arr[21] = val; break;
      case 22: arr[22] = val; break;
      case 23: arr[23] = val; break;
      case 24: arr[24] = val; break;
      case 25: arr[25] = val; break;
      case 26: arr[26] = val; break;
      case 27: arr[27] = val; break;
      case 28: arr[28] = val; break;
      case 29: arr[29] = val; break;
      case 30: arr[30] = val; break;
      case 31: arr[31] = val; break;
    }
  }
};
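
#include <cfloat> // FLT_MAX, used only by the example kernel below.

// Example usage: a minimal, illustrative sketch only. ExampleRegisterArrayKernel
// and its parameter names are hypothetical and are not referenced anywhere else;
// they just show the intended access pattern. Each thread keeps a small
// fixed-size array of running minima over a strided slice of `dists`. Because
// the insertion slot is a dynamic index, reads and writes go through
// RegisterIndexUtils so the compiler can still keep the array in registers.
// K must be <= 32, since the switch statements above only cover indices 0..31.
template <int K>
__global__ void ExampleRegisterArrayKernel(
    const float* __restrict__ dists, // (num,) candidate distances
    float* __restrict__ out, // (num_threads, K) per-thread running minima
    const int num) {
  // Per-thread scratch array; small enough to live in registers as long as
  // every access uses a static (or switch-dispatched) index.
  float mins[K];
  for (int k = 0; k < K; ++k) {
    mins[k] = FLT_MAX;
  }
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = gridDim.x * blockDim.x;
  for (int i = tid; i < num; i += num_threads) {
    // slot is a dynamic index, so use the switch-based accessors rather than
    // indexing mins[slot] directly.
    const int slot = i % K;
    const float cur = RegisterIndexUtils<float, K>::get(mins, slot);
    if (dists[i] < cur) {
      RegisterIndexUtils<float, K>::set(mins, slot, dists[i]);
    }
  }
  // A static loop index is safe to use directly.
  for (int k = 0; k < K; ++k) {
    out[tid * K + k] = mins[k];
  }
}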