Implement K-Nearest Neighbors

Summary:
Implements K-Nearest Neighbors with C++ and CUDA versions.

KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on D, or on both D and K, so we use template metaprogramming to compile specialized versions for ranges of D and K.

These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend.

I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API.

I still want to benchmark against FAISS to see how far away we are from their performance.
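
For reference, a minimal usage sketch of the new op from Python (the exact API may still change, as noted above; shapes and argument names follow pytorch3d/ops/knn.py):

    import torch
    from pytorch3d.ops.knn import knn_points_idx

    p1 = torch.randn(2, 1024, 3, device='cuda')  # (N, P1, D) query points
    p2 = torch.randn(2, 4096, 3, device='cuda')  # (N, P2, D) reference points
    # version=-1 lets the backend pick a kernel heuristically based on D and K
    idx, dists = knn_points_idx(p1, p2, K=8, sorted=True, version=-1)
    # idx: (2, 1024, 8) LongTensor of neighbor indices into p2
    # dists: (2, 1024, 8) squared distances to those neighbors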

Reviewed By: bottler

Differential Revision: D19729286

fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
Commit 870290df34 (parent 02d4968ee0)
Justin Johnson, 2020-03-26 13:37:32 -07:00, committed by Facebook GitHub Bot
12 changed files with 1328 additions and 1 deletion

pytorch3d/csrc/dispatch.cuh (new file, 261 lines)
@@ -0,0 +1,261 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of functions.
// This is especially useful for CUDA kernels, since specializing them to particular
// input sizes can often allow the compiler to unroll loops and place arrays into
// registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
// static void run(T y) {
// T val = x * x + y;
// std::cout << val << std::endl;
// }
// };
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// if (x == 0) {
// SquareOffset<T, 0>::run(y);
// } else if (x == 1) {
// SquareOffset<T, 1>::run(y);
// } else if (x == 2) {
// SquareOffset<T, 2>::run(y);
// }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of manual editing of the dispatch function. Also, if we
// want to provide compile-time specializations for a function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// constexpr int64_t xmin = 0;
// constexpr int64_t xmax = 2;
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
// static void run(T z, T w) {
// T val = x + y + z + w;
// std::cout << val << std::endl;
// }
// };
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, int z, int w) {
// constexpr int64_t xmin = 1;
// constexpr int64_t xmax = 3;
// constexpr int64_t ymin = 2;
// constexpr int64_t ymax = 5;
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.
// Define some helper structs in an anonymous namespace.
namespace {
// 1D dispatch: general case.
// Kernel is the function we want to dispatch to; it should take a typename and
// an int64_t as template args, and it should define a static void function
// run which takes any number of arguments of any type.
// In order to dispatch, we will take an additional template argument curN,
// and increment it via template recursion until it is equal to the run-time
// argument N.
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
typename... Args
>
struct DispatchKernelHelper1D {
static void run(const int64_t N, Args... args) {
if (curN == N) {
// The compile-time value curN is equal to the run-time value N, so we
// can dispatch to the run method of the Kernel.
Kernel<T, curN>::run(args...);
} else if (curN < N) {
// Increment curN via template recursion
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
}
// We shouldn't get here -- throw an error?
}
};
// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args
>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
static void run(const int64_t N, Args... args) {
if (N == maxN) {
Kernel<T, maxN>::run(args...);
}
// We shouldn't get here -- throw an error?
}
};
// 2D dispatch, general case.
// This is similar to the 1D case: we take additional template args curN and
// curM, and increment them via template recursion until they are equal to
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN, int64_t curN,
int64_t minM, int64_t maxM, int64_t curM,
typename... Args
>
struct DispatchKernelHelper2D {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && curM == M) {
Kernel<T, curN, curM>::run(args...);
} else if (curN < N && curM < M) {
// Increment both curN and curM. This isn't strictly necessary; we could
// just increment one or the other at each step. But this helps to cut down
// on the number of recursive calls we make.
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
} else if (curN < N) {
// Increment curN only
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
} else if (curM < M) {
// Increment curM only
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
}
}
};
// 2D dispatch, specialization for curN == maxN
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM, int64_t curM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && curM == M) {
Kernel<T, maxN, curM>::run(args...);
} else if (curM < maxM) {
DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch, specialization for curM == maxM
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN, int64_t curN,
int64_t minM, int64_t maxM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && maxM == M) {
Kernel<T, curN, maxM>::run(args...);
} else if (curN < maxN) {
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch, specialization for curN == maxN, curM == maxM
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && maxM == M) {
Kernel<T, maxN, maxM>::run(args...);
}
// We should not get here -- throw an error?
}
};
} // namespace
// This is the function we expect users to call to dispatch to 1D functions
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args
>
void DispatchKernel1D(const int64_t N, Args... args) {
if (minN <= N && N <= maxN) {
// Kick off the template recursion by calling the Helper with curN = minN
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
}
// Maybe throw an error if we tried to dispatch outside the allowed range?
}
// This is the function we expect users to call to dispatch to 2D functions
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM,
typename... Args
>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
// Kick off the template recursion by calling the Helper with curN = minN
// and curM = minM
DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
}
// Maybe throw an error if we tried to dispatch outside the specified range?
}

@@ -6,6 +6,7 @@
#include "compositing/weighted_sum.h"
#include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h"
#include "knn/knn.h"
#include "nearest_neighbor_points/nearest_neighbor_points.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "rasterize_meshes/rasterize_meshes.h"
@@ -16,6 +17,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
m.def("packed_to_padded", &PackedToPadded);
m.def("padded_to_packed", &PaddedToPacked);
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("nn_points_idx", &NearestNeighborIdx);
m.def("gather_scatter", &gather_scatter);
m.def("rasterize_points", &RasterizePoints);

@@ -0,0 +1,120 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
// This converts dynamic array lookups into static array lookups, for small
// arrays up to size 32.
//
// Suppose we have a small thread-local array:
//
// float vals[10];
//
// Ideally we should only index this array using static indices:
//
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
//
// If we do so, then the CUDA compiler may be able to place the array into
// registers, which can have a big performance improvement. However if we
// access the array dynamically, then the compiler may force the array into
// local memory, which has the same latency as global memory.
//
// These functions convert dynamic array access into static array access
// using a brute-force lookup table. They can be used like this:
//
// float vals[10];
// int idx = 3;
// float val = 3.14f;
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
//
// The implementation is based on fbcuda/RegisterUtils.cuh:
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
// To avoid depending on the entire library, we just reimplement these two
// functions. The fbcuda implementation is a bit more sophisticated, and uses
// the preprocessor to generate switch statements that go up to N for each
// value of N. We are lazy and just have a giant explicit switch statement.
//
// We might be able to use a template metaprogramming approach similar to
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
// for dispatching to the correct CUDA kernel on the host, while this
// is intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template<typename T, int N>
struct RegisterIndexUtils {
__device__ __forceinline__ static T get(const T arr[N], int idx) {
if (idx < 0 || idx >= N) return T();
switch (idx) {
case 0: return arr[0];
case 1: return arr[1];
case 2: return arr[2];
case 3: return arr[3];
case 4: return arr[4];
case 5: return arr[5];
case 6: return arr[6];
case 7: return arr[7];
case 8: return arr[8];
case 9: return arr[9];
case 10: return arr[10];
case 11: return arr[11];
case 12: return arr[12];
case 13: return arr[13];
case 14: return arr[14];
case 15: return arr[15];
case 16: return arr[16];
case 17: return arr[17];
case 18: return arr[18];
case 19: return arr[19];
case 20: return arr[20];
case 21: return arr[21];
case 22: return arr[22];
case 23: return arr[23];
case 24: return arr[24];
case 25: return arr[25];
case 26: return arr[26];
case 27: return arr[27];
case 28: return arr[28];
case 29: return arr[29];
case 30: return arr[30];
case 31: return arr[31];
};
return T();
}
__device__ __forceinline__ static void set(T arr[N], int idx, T val) {
if (idx < 0 || idx >= N) return;
switch (idx) {
case 0: arr[0] = val; break;
case 1: arr[1] = val; break;
case 2: arr[2] = val; break;
case 3: arr[3] = val; break;
case 4: arr[4] = val; break;
case 5: arr[5] = val; break;
case 6: arr[6] = val; break;
case 7: arr[7] = val; break;
case 8: arr[8] = val; break;
case 9: arr[9] = val; break;
case 10: arr[10] = val; break;
case 11: arr[11] = val; break;
case 12: arr[12] = val; break;
case 13: arr[13] = val; break;
case 14: arr[14] = val; break;
case 15: arr[15] = val; break;
case 16: arr[16] = val; break;
case 17: arr[17] = val; break;
case 18: arr[18] = val; break;
case 19: arr[19] = val; break;
case 20: arr[20] = val; break;
case 21: arr[21] = val; break;
case 22: arr[22] = val; break;
case 23: arr[23] = val; break;
case 24: arr[24] = val; break;
case 25: arr[25] = val; break;
case 26: arr[26] = val; break;
case 27: arr[27] = val; break;
case 28: arr[28] = val; break;
case 29: arr[29] = val; break;
case 30: arr[30] = val; break;
case 31: arr[31] = val; break;
}
}
};

pytorch3d/csrc/knn/knn.cu (new file, 369 lines)
@@ -0,0 +1,369 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <float.h>
#include <iostream>
#include <tuple>
#include "dispatch.cuh"
#include "mink.cuh"
template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t D,
const size_t K) {
// Stupid version: Make each thread handle one query point and loop over
// all P2 target points. There are N * P1 input points to handle, so
// do a trivial parallelization over threads.
// Store both dists and indices for knn in global memory.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
int offset = n * P1 * K + p1 * K;
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < P2; ++p2) {
// Find the distance between points1[n, p1] and points2[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
scalar_t diff = coord1 - coord2;
dist += diff * diff;
}
mink.add(dist, p2);
}
}
}
template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
// Same idea as the previous version, but hoist D into a template argument
// so we can cache the current point in a thread-local array. We still store
// the current best K dists and indices in global memory, so this should work
// for very large K and fairly large D.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int offset = n * P1 * K + p1 * K;
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < P2; ++p2) {
// Find the distance between cur_point and points2[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
dist += diff * diff;
}
mink.add(dist, p2);
}
}
}
// This is a shim functor to allow us to dispatch using DispatchKernel1D
template <typename scalar_t, int64_t D>
struct KNearestNeighborV1Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
KNearestNeighborKernelV1<scalar_t, D>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
}
};
template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
// Same general implementation as V1, but also hoist K into a template arg.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
for (int p2 = 0; p2 < P2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
dist += diff * diff;
}
mink.add(dist, p2);
}
for (int k = 0; k < mink.size(); ++k) {
idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
dists[n * P1 * K + p1 * K + k] = min_dists[k];
}
}
}
// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV2Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
KNearestNeighborKernelV2<scalar_t, D, K>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
}
};
template <typename scalar_t, int D, int K>
__global__ void KNearestNeighborKernelV3(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
// Same idea as V2, but use register indexing for thread-local arrays.
// Enabling sorting for this version leads to huge slowdowns; I suspect
// that it forces min_dists into local memory rather than registers.
// As a result this version is always unsorted.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
for (int p2 = 0; p2 < P2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
dist += diff * diff;
}
mink.add(dist, p2);
}
for (int k = 0; k < mink.size(); ++k) {
idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
dists[n * P1 * K + p1 * K + k] = min_dists[k];
}
}
}
// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV3Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
KNearestNeighborKernelV3<scalar_t, D, K>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
}
};
constexpr int V1_MIN_D = 1;
constexpr int V1_MAX_D = 32;
constexpr int V2_MIN_D = 1;
constexpr int V2_MAX_D = 8;
constexpr int V2_MIN_K = 1;
constexpr int V2_MAX_K = 32;
constexpr int V3_MIN_D = 1;
constexpr int V3_MAX_D = 8;
constexpr int V3_MIN_K = 1;
constexpr int V3_MAX_K = 4;
bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
return min <= x && x <= max;
}
bool CheckVersion(int version, const int64_t D, const int64_t K) {
if (version == 0) {
return true;
} else if (version == 1) {
return InBounds(V1_MIN_D, D, V1_MAX_D);
} else if (version == 2) {
return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K);
} else if (version == 3) {
return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K);
}
return false;
}
int ChooseVersion(const int64_t D, const int64_t K) {
for (int version = 3; version >= 1; version--) {
if (CheckVersion(version, D, K)) {
return version;
}
}
return 0;
}
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
int K,
int version) {
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const int64_t K_64 = K;
AT_ASSERTM(p1.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = p1.options().dtype(at::kLong);
auto idxs = at::full({N, P1, K}, -1, long_dtype);
auto dists = at::full({N, P1, K}, -1, p1.options());
if (version < 0) {
version = ChooseVersion(D, K);
} else if (!CheckVersion(version, D, K)) {
int new_version = ChooseVersion(D, K);
std::cout << "WARNING: Requested KNN version " << version
<< " is not compatible with D = " << D << "; K = " << K
<< ". Falling back to version = " << new_version << std::endl;
version = new_version;
}
// At this point we should have a valid version no matter what data the user
// gave us. But we check once more to be sure; this time we assert, since
// failing at this point means we have a bug in our version selection or
// checking code.
AT_ASSERTM(CheckVersion(version, D, K), "Invalid version");
const size_t threads = 256;
const size_t blocks = 256;
if (version == 0) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
KNearestNeighborKernelV0<scalar_t>
<<<blocks, threads>>>(
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
D,
K);
}));
} else if (version == 1) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
DispatchKernel1D<
KNearestNeighborV1Functor,
scalar_t,
V1_MIN_D,
V1_MAX_D>(
D,
blocks,
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
K);
}));
} else if (version == 2) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV2Functor,
scalar_t,
V2_MIN_D,
V2_MAX_D,
V2_MIN_K,
V2_MAX_K>(
D,
K_64,
blocks,
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2);
}));
} else if (version == 3) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV3Functor,
scalar_t,
V3_MIN_D,
V3_MAX_D,
V3_MIN_K,
V3_MAX_K>(
D,
K_64,
blocks,
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2);
}));
}
return std::make_tuple(idxs, dists);
}

pytorch3d/csrc/knn/knn.h (new file, 54 lines)
@@ -0,0 +1,54 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "pytorch3d_cutils.h"
// Compute indices of K nearest neighbors in pointcloud p2 to points
// in pointcloud p1.
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// K: int giving the number of nearest points to return.
// sorted: bool telling whether to sort the K returned points by their
// distance.
// version: Integer telling which implementation to use.
// TODO(jcjohns): Document this more, or maybe remove it before
// landing.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) giving the squared
// distances to the K nearest neighbors.
// CPU implementation.
std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
int K,
int version);
// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
const at::Tensor& p1,
const at::Tensor& p2,
int K,
int version) {
if (p1.type().is_cuda() || p2.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(p1);
CHECK_CONTIGUOUS_CUDA(p2);
return KNearestNeighborIdxCuda(p1, p2, K, version);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborIdxCpu(p1, p2, K);
}

@@ -0,0 +1,52 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <queue>
#include <tuple>
std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);
auto long_opts = p1.options().dtype(torch::kInt64);
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto dists_a = dists.accessor<float, 3>();
for (int n = 0; n < N; ++n) {
for (int i1 = 0; i1 < P1; ++i1) {
// Use a priority queue to store (distance, index) tuples.
std::priority_queue<std::tuple<float, int>> q;
for (int i2 = 0; i2 < P2; ++i2) {
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
dist += diff * diff;
}
int size = static_cast<int>(q.size());
if (size < K || dist < std::get<0>(q.top())) {
q.emplace(dist, i2);
if (size >= K) {
q.pop();
}
}
}
while (!q.empty()) {
auto t = q.top();
q.pop();
const int k = q.size();
dists_a[n][i1][k] = std::get<0>(t);
idxs_a[n][i1][k] = std::get<1>(t);
}
}
}
return std::make_tuple(idxs, dists);
}
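
The CPU loop above keeps the K best candidates in a max-heap: a candidate is pushed only when the heap is not yet full or it beats the current worst, and the final pops write results from farthest to nearest using the post-pop queue size as the output slot. A rough Python analogue of the same selection logic, included purely as an illustration (not part of this diff):

    import heapq

    def knn_one_query(query, points, K):
        # Keep the K smallest squared distances in a max-heap by storing
        # negated distances in Python's min-heap.
        heap = []  # entries are (-dist, index)
        for j, p in enumerate(points):
            dist = sum((a - b) ** 2 for a, b in zip(query, p))
            if len(heap) < K or dist < -heap[0][0]:
                heapq.heappush(heap, (-dist, j))
                if len(heap) > K:
                    heapq.heappop(heap)  # drop the current farthest point
        # Pop from farthest to nearest, filling the output back to front,
        # mirroring the post-pop q.size() indexing in the C++ code above.
        out = [None] * len(heap)
        while heap:
            neg_dist, j = heapq.heappop(heap)
            out[len(heap)] = (j, -neg_dist)
        return out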

pytorch3d/csrc/mink.cuh (new file, 162 lines)
@@ -0,0 +1,162 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#define MINK_H
#include "index_utils.cuh"
// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
// in arrays passed to the constructor.
//
// The implementation is generic; it can be used for any key type that supports
// the < operator, and can be used with any value type.
//
// Example usage:
//
// float keys[K];
// int values[K];
// MinK<float, int> mink(keys, values, K);
// for (...) {
// // Produce some key and value from somewhere
// mink.add(key, value);
// }
// mink.sort();
//
// Now keys and values store the smallest K keys seen so far and the values
// associated to these keys:
//
// for (int k = 0; k < K; ++k) {
// float key_k = keys[k];
// int value_k = values[k];
// }
template<typename key_t, typename value_t>
class MinK {
public:
// Constructor.
//
// Arguments:
// keys: Array in which to store keys
// values: Array in which to store values
// K: How many values to keep track of
__device__ MinK(key_t *keys, value_t *vals, int K) :
keys(keys), vals(vals), K(K), _size(0) { }
// Try to add a new key and associated value to the data structure. If the key
// is one of the smallest K seen so far then it will be kept; otherwise it
// will not be kept.
//
// This takes O(1) operations if the new key is not kept, or if the structure
// currently contains fewer than K elements. Otherwise this takes O(K) time.
//
// Arguments:
// key: The key to add
// val: The value associated to the key
__device__ __forceinline__ void add(const key_t &key, const value_t &val) {
if (_size < K) {
keys[_size] = key;
vals[_size] = val;
if (_size == 0 || key > max_key) {
max_key = key;
max_idx = _size;
}
_size++;
} else if (key < max_key) {
keys[max_idx] = key;
vals[max_idx] = val;
max_key = key;
for (int k = 0; k < K; ++k) {
key_t cur_key = keys[k];
if (cur_key > max_key) {
max_key = cur_key;
max_idx = k;
}
}
}
}
// Get the number of items currently stored in the structure.
// This takes O(1) time.
__device__ __forceinline__ int size() {
return _size;
}
// Sort the items stored in the structure using bubble sort.
// This takes O(K^2) time.
__device__ __forceinline__ void sort() {
for (int i = 0; i < _size - 1; ++i) {
for (int j = 0; j < _size - i - 1; ++j) {
if (keys[j + 1] < keys[j]) {
key_t key = keys[j];
value_t val = vals[j];
keys[j] = keys[j + 1];
vals[j] = vals[j + 1];
keys[j + 1] = key;
vals[j + 1] = val;
}
}
}
}
private:
key_t *keys;
value_t *vals;
int K;
int _size;
key_t max_key;
int max_idx;
};
// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
// fast access.
//
// This has the same API as MinK, but doesn't support sorting.
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template<typename key_t, typename value_t, int K>
class RegisterMinK {
public:
__device__ RegisterMinK(key_t *keys, value_t *vals) :
keys(keys), vals(vals), _size(0) {}
__device__ __forceinline__ void add(const key_t &key, const value_t &val) {
if (_size < K) {
RegisterIndexUtils<key_t, K>::set(keys, _size, key);
RegisterIndexUtils<value_t, K>::set(vals, _size, val);
if (_size == 0 || key > max_key) {
max_key = key;
max_idx = _size;
}
_size++;
} else if (key < max_key) {
RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
max_key = key;
for (int k = 0; k < K; ++k) {
key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
if (cur_key > max_key) {
max_key = cur_key;
max_idx = k;
}
}
}
}
__device__ __forceinline__ int size() {
return _size;
}
private:
key_t *keys;
value_t *vals;
int _size;
key_t max_key;
int max_idx;
};
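
As a plain-Python illustration of the bookkeeping MinK does (an informal sketch only, not part of this diff): store up to K keys and values unsorted, remember where the current largest key lives, and only pay O(K) to rescan when that largest key is evicted.

    class MinKPy:
        """Keep the K smallest (key, value) pairs seen so far, unsorted."""

        def __init__(self, K):
            self.K = K
            self.keys = []
            self.vals = []
            self.max_idx = -1  # position of the largest stored key

        def add(self, key, val):
            if len(self.keys) < self.K:
                self.keys.append(key)
                self.vals.append(val)
                if self.max_idx < 0 or key > self.keys[self.max_idx]:
                    self.max_idx = len(self.keys) - 1
            elif key < self.keys[self.max_idx]:
                # Evict the current maximum, then rescan for the new one (O(K)).
                self.keys[self.max_idx] = key
                self.vals[self.max_idx] = val
                self.max_idx = max(range(self.K), key=self.keys.__getitem__)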

@@ -39,4 +39,4 @@ at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) {
#endif
}
return NearestNeighborIdxCpu(p1, p2);
};
}

pytorch3d/ops/knn.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from pytorch3d import _C
def knn_points_idx(p1, p2, K, sorted=False, version=-1):
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
containing P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
containing P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return
sorted: Whether to sort the resulting points.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of
the inputs.
Returns:
idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
nearest neighbors to p1[n, i] in p2[n]. If sorted=True, then
p2[n, j] is the kth nearest neighbor to p1[n, i].
dists: Tensor of shape (N, P1, K) giving the squared distances to the
nearest neighbors.
"""
idx, dists = _C.knn_points_idx(p1, p2, K, version)
if sorted:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
return idx, dists
@torch.no_grad()
def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
"""
Naive PyTorch implementation of K-Nearest Neighbors.
This is much less efficient than _C.knn_points_idx, but we include this
naive implementation for testing and benchmarking.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
containing P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
containing P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return
sorted: Whether to sort the resulting points.
Returns:
idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
nearest neighbors to p1[n, i] in p2[n]. If sorted=True, then
p2[n, j] is the kth nearest neighbor to p1[n, i].
dists: Tensor of shape (N, P1, K) giving the squared distances to the
nearest neighbors.
"""
N, P1, D = p1.shape
_N, P2, _D = p2.shape
assert N == _N and D == _D
diffs = p1.view(N, P1, 1, D) - p2.view(N, 1, P2, D)
dists2 = (diffs * diffs).sum(dim=3)
out = dists2.topk(K, dim=2, largest=False, sorted=sorted)
return out.indices, out.values
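
A small usage sketch of the two entry points defined above, checking that the custom kernels and the naive path agree (illustrative only, not part of the committed file):

    import torch
    from pytorch3d.ops.knn import knn_points_idx, _knn_points_idx_naive

    p1 = torch.randn(2, 128, 3, device='cuda')
    p2 = torch.randn(2, 512, 3, device='cuda')
    idx, dists = knn_points_idx(p1, p2, K=4, sorted=True)          # C++/CUDA path
    idx_ref, dists_ref = _knn_points_idx_naive(p1, p2, K=4, sorted=True)
    assert torch.equal(idx, idx_ref)
    assert torch.allclose(dists, dists_ref)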

tests/bm_knn.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from pytorch3d import _C
from pytorch3d.ops.knn import _knn_points_idx_naive
def bm_knn() -> None:
""" Entry point for the benchmark """
benchmark_knn_cpu()
benchmark_knn_cuda_vs_naive()
benchmark_knn_cuda_versions()
def benchmark_knn_cuda_versions() -> None:
# Compare our different KNN implementations,
# and also compare against our existing 1-NN
Ns = [1, 2]
Ps = [4096, 16384]
Ds = [3]
Ks = [1, 4, 16, 64]
versions = [0, 1, 2, 3]
knn_kwargs, nn_kwargs = [], []
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
if version == 2 and K > 32:
continue
if version == 3 and K > 4:
continue
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({'N': N, 'D': D, 'P': P})
benchmark(
knn_cuda_with_init,
'KNN_CUDA_VERSIONS',
knn_kwargs,
warmup_iters=1,
)
benchmark(
nn_cuda_with_init,
'NN_CUDA',
nn_kwargs,
warmup_iters=1,
)
def benchmark_knn_cuda_vs_naive() -> None:
# Compare against naive pytorch version of KNN
Ns = [1, 2, 4]
Ps = [1024, 4096, 16384, 65536]
Ds = [3]
Ks = [1, 2, 4, 8, 16]
knn_kwargs, naive_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
if P <= 4096:
naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
benchmark(
knn_python_cuda_with_init,
'KNN_CUDA_PYTHON',
naive_kwargs,
warmup_iters=1,
)
benchmark(
knn_cuda_with_init,
'KNN_CUDA',
knn_kwargs,
warmup_iters=1,
)
def benchmark_knn_cpu() -> None:
Ns = [1, 2]
Ps = [256, 512]
Ds = [3]
Ks = [1, 2, 4]
knn_kwargs, nn_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({'N': N, 'D': D, 'P': P})
benchmark(
knn_python_cpu_with_init,
'KNN_CPU_PYTHON',
knn_kwargs,
warmup_iters=1,
)
benchmark(
knn_cpu_with_init,
'KNN_CPU_CPP',
knn_kwargs,
warmup_iters=1,
)
benchmark(
nn_cpu_with_init,
'NN_CPU_CPP',
nn_kwargs,
warmup_iters=1,
)
def knn_cuda_with_init(N, D, P, K, v=-1):
device = torch.device('cuda:0')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
def knn():
_C.knn_points_idx(x, y, K, v)
torch.cuda.synchronize()
return knn
def knn_cpu_with_init(N, D, P, K):
device = torch.device('cpu')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
def knn():
_C.knn_points_idx(x, y, K, 0)
return knn
def knn_python_cuda_with_init(N, D, P, K):
device = torch.device('cuda')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
def knn():
_knn_points_idx_naive(x, y, K)
torch.cuda.synchronize()
return knn
def knn_python_cpu_with_init(N, D, P, K):
device = torch.device('cpu')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
def knn():
_knn_points_idx_naive(x, y, K)
return knn
def nn_cuda_with_init(N, D, P):
device = torch.device('cuda')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
def knn():
_C.nn_points_idx(x, y)
torch.cuda.synchronize()
return knn
def nn_cpu_with_init(N, D, P):
device = torch.device('cpu')
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
def knn():
_C.nn_points_idx(x, y)
return knn

tests/test_knn.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
import torch
from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
class TestKNN(unittest.TestCase):
def _check_knn_result(self, out1, out2, sorted):
# When sorted=True, points should be sorted by distance and should
# match between implementations. When sorted=False we only want to
# check that we got the same set of indices, so we sort the indices by
# index value.
idx1, dist1 = out1
idx2, dist2 = out2
if not sorted:
idx1 = idx1.sort(dim=2).values
idx2 = idx2.sort(dim=2).values
dist1 = dist1.sort(dim=2).values
dist2 = dist2.sort(dim=2).values
if not torch.all(idx1 == idx2):
print(idx1)
print(idx2)
self.assertTrue(torch.all(idx1 == idx2))
self.assertTrue(torch.allclose(dist1, dist2))
def test_knn_vs_python_cpu(self):
""" Test CPU output vs PyTorch implementation """
device = torch.device('cpu')
Ns = [1, 4]
Ds = [2, 3]
P1s = [1, 10, 101]
P2s = [10, 101]
Ks = [1, 3, 10]
sorts = [True, False]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, K, sort)
out2 = knn_points_idx(x, y, K, sort)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda(self):
""" Test CUDA output vs PyTorch implementation """
device = torch.device('cuda')
Ns = [1, 4]
Ds = [2, 3, 8]
P1s = [1, 8, 64, 128, 1001]
P2s = [32, 128, 513]
Ks = [1, 3, 10]
sorts = [True, False]
versions = [0, 1, 2, 3]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
for version in versions:
if version == 3 and K > 4:
continue
out2 = knn_points_idx(x, y, K, sort, version)
self._check_knn_result(out1, out2, sort)

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
import torch
from pytorch3d import _C