Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Implement K-Nearest Neighbors
Summary: Implements K-Nearest Neighbors with C++ and CUDA versions.

KNN in CUDA is highly nontrivial. I've implemented a few different versions of the kernel, and we heuristically dispatch to different kernels based on the problem size. Some of the kernels rely on template specialization on either D or K, so we use template metaprogramming to compile specialized versions for ranges of D and K.

These kernels are up to 3x faster than our existing 1-nearest-neighbor kernels, so we should also consider swapping out `nn_points_idx` to use these kernels in the backend.

I've been working mostly on the CUDA kernels, and haven't converged on the correct Python API. I still want to benchmark against FAISS to see how far away we are from their performance.

Reviewed By: bottler
Differential Revision: D19729286
fbshipit-source-id: 608ffbb7030c21fe4008f330522f4890f0c3c21a
This commit is contained in:
parent
02d4968ee0
commit
870290df34
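For context, here is a minimal sketch of how the new op is meant to be called from Python once this commit is built. It only uses the `knn_points_idx` wrapper added in pytorch3d/ops/knn.py below; shapes and argument names follow that file, and per the summary the public API may still change.

import torch
from pytorch3d.ops.knn import knn_points_idx

# Two batches of point clouds: 1000 query points and 2000 reference points in 3D.
p1 = torch.randn(2, 1000, 3, device="cuda")
p2 = torch.randn(2, 2000, 3, device="cuda")

# idx[n, i, k] = j means p2[n, j] is one of the 8 nearest neighbors of p1[n, i].
# version=-1 lets the backend pick a kernel based on D and K.
idx, dists = knn_points_idx(p1, p2, K=8, sorted=True, version=-1)
print(idx.shape, dists.shape)  # (2, 1000, 8), (2, 1000, 8)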
pytorch3d/csrc/dispatch.cuh (new file, 261 lines)
@@ -0,0 +1,261 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of functions.
// This is especially useful for CUDA kernels, since specializing them to particular
// input sizes can often allow the compiler to unroll loops and place arrays into
// registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
//   static void run(T y) {
//     T val = x * x + y;
//     std::cout << val << std::endl;
//   }
// }
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//   if (x == 0) {
//     SquareOffset<T, 0>::run(y);
//   } else if (x == 1) {
//     SquareOffset<T, 1>::run(y);
//   } else if (x == 2) {
//     SquareOffset<T, 2>::run(y);
//   }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
//   constexpr int64_t xmin = 0;
//   constexpr int64_t xmax = 2;
//   DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
//   static void run(T z, T w) {
//     T val = x + y + z + w;
//     std::cout << val << std::endl;
//   }
// }
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, int z, int w) {
//   constexpr int64_t xmin = 1;
//   constexpr int64_t xmax = 3;
//   constexpr int64_t ymin = 2;
//   constexpr int64_t ymax = 5;
//   DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of Sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.

// Define some helper structs in an anonymous namespace.
namespace {

// 1D dispatch: general case.
// Kernel is the function we want to dispatch to; it should take a typename and
// an int64_t as template args, and it should define a static void function
// run which takes any number of arguments of any type.
// In order to dispatch, we will take an additional template argument curN,
// and increment it via template recursion until it is equal to the run-time
// argument N.
template<
  template<typename, int64_t> class Kernel,
  typename T,
  int64_t minN,
  int64_t maxN,
  int64_t curN,
  typename... Args
>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (curN == N) {
      // The compile-time value curN is equal to the run-time value N, so we
      // can dispatch to the run method of the Kernel.
      Kernel<T, curN>::run(args...);
    } else if (curN < N) {
      // Increment curN via template recursion
      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
    }
    // We shouldn't get here -- throw an error?
  }
};

// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template<
  template<typename, int64_t> class Kernel,
  typename T,
  int64_t minN,
  int64_t maxN,
  typename... Args
>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
  static void run(const int64_t N, Args... args) {
    if (N == maxN) {
      Kernel<T, maxN>::run(args...);
    }
    // We shouldn't get here -- throw an error?
  }
};

// 2D dispatch, general case.
// This is similar to the 1D case: we take additional template args curN and
// curM, and increment them via template recursion until they are equal to
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN, int64_t curN,
  int64_t minM, int64_t maxM, int64_t curM,
  typename... Args
>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (curN == N && curM == M) {
      Kernel<T, curN, curM>::run(args...);
    } else if (curN < N && curM < M) {
      // Increment both curN and curM. This isn't strictly necessary; we could
      // just increment one or the other at each step. But this helps to cut
      // down on the number of recursive calls we make.
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
    } else if (curN < N) {
      // Increment curN only
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
    } else if (curM < M) {
      // Increment curM only
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
    }
  }
};

// 2D dispatch, specialization for curN == maxN
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN,
  int64_t minM, int64_t maxM, int64_t curM,
  typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (maxN == N && curM == M) {
      Kernel<T, maxN, curM>::run(args...);
    } else if (curM < maxM) {
      DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};

// 2D dispatch, specialization for curM == maxM
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN, int64_t curN,
  int64_t minM, int64_t maxM,
  typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (curN == N && maxM == M) {
      Kernel<T, curN, maxM>::run(args...);
    } else if (curN < maxN) {
      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
    }
    // We should not get here -- throw an error?
  }
};

// 2D dispatch, specialization for curN == maxN, curM == maxM
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN,
  int64_t minM, int64_t maxM,
  typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    if (maxN == N && maxM == M) {
      Kernel<T, maxN, maxM>::run(args...);
    }
    // We should not get here -- throw an error?
  }
};

} // namespace

// This is the function we expect users to call to dispatch to 1D functions
template<
  template<typename, int64_t> class Kernel,
  typename T,
  int64_t minN,
  int64_t maxN,
  typename... Args
>
void DispatchKernel1D(const int64_t N, Args... args) {
  if (minN <= N && N <= maxN) {
    // Kick off the template recursion by calling the Helper with curN = minN
    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the allowed range?
}

// This is the function we expect users to call to dispatch to 2D functions
template<
  template<typename, int64_t, int64_t> class Kernel,
  typename T,
  int64_t minN, int64_t maxN,
  int64_t minM, int64_t maxM,
  typename... Args
>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
  if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
    // Kick off the template recursion by calling the Helper with curN = minN
    // and curM = minM
    DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
  }
  // Maybe throw an error if we tried to dispatch outside the specified range?
}
@@ -6,6 +6,7 @@
#include "compositing/weighted_sum.h"
#include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h"
#include "knn/knn.h"
#include "nearest_neighbor_points/nearest_neighbor_points.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "rasterize_meshes/rasterize_meshes.h"
@@ -16,6 +17,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
  m.def("packed_to_padded", &PackedToPadded);
  m.def("padded_to_packed", &PaddedToPacked);
  m.def("knn_points_idx", &KNearestNeighborIdx);
  m.def("nn_points_idx", &NearestNeighborIdx);
  m.def("gather_scatter", &gather_scatter);
  m.def("rasterize_points", &RasterizePoints);
pytorch3d/csrc/index_utils.cuh (new file, 120 lines)
@@ -0,0 +1,120 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

// This converts dynamic array lookups into static array lookups, for small
// arrays up to size 32.
//
// Suppose we have a small thread-local array:
//
// float vals[10];
//
// Ideally we should only index this array using static indices:
//
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
//
// If we do so, then the CUDA compiler may be able to place the array into
// registers, which can have a big performance improvement. However if we
// access the array dynamically, then the compiler may force the array into
// local memory, which has the same latency as global memory.
//
// These functions convert dynamic array access into static array access
// using a brute-force lookup table. It can be used like this:
//
// float vals[10];
// int idx = 3;
// float val = 3.14f;
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
//
// The implementation is based on fbcuda/RegisterUtils.cuh:
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
// To avoid depending on the entire library, we just reimplement these two
// functions. The fbcuda implementation is a bit more sophisticated, and uses
// the preprocessor to generate switch statements that go up to N for each
// value of N. We are lazy and just have a giant explicit switch statement.
//
// We might be able to use a template metaprogramming approach similar to
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
// for dispatching to the correct CUDA kernel on the host, while this is
// intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template<typename T, int N>
struct RegisterIndexUtils {
  __device__ __forceinline__ static T get(const T arr[N], int idx) {
    if (idx < 0 || idx >= N) return T();
    switch (idx) {
      case 0: return arr[0];
      case 1: return arr[1];
      case 2: return arr[2];
      case 3: return arr[3];
      case 4: return arr[4];
      case 5: return arr[5];
      case 6: return arr[6];
      case 7: return arr[7];
      case 8: return arr[8];
      case 9: return arr[9];
      case 10: return arr[10];
      case 11: return arr[11];
      case 12: return arr[12];
      case 13: return arr[13];
      case 14: return arr[14];
      case 15: return arr[15];
      case 16: return arr[16];
      case 17: return arr[17];
      case 18: return arr[18];
      case 19: return arr[19];
      case 20: return arr[20];
      case 21: return arr[21];
      case 22: return arr[22];
      case 23: return arr[23];
      case 24: return arr[24];
      case 25: return arr[25];
      case 26: return arr[26];
      case 27: return arr[27];
      case 28: return arr[28];
      case 29: return arr[29];
      case 30: return arr[30];
      case 31: return arr[31];
    };
    return T();
  }

  __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
    if (idx < 0 || idx >= N) return;
    switch (idx) {
      case 0: arr[0] = val; break;
      case 1: arr[1] = val; break;
      case 2: arr[2] = val; break;
      case 3: arr[3] = val; break;
      case 4: arr[4] = val; break;
      case 5: arr[5] = val; break;
      case 6: arr[6] = val; break;
      case 7: arr[7] = val; break;
      case 8: arr[8] = val; break;
      case 9: arr[9] = val; break;
      case 10: arr[10] = val; break;
      case 11: arr[11] = val; break;
      case 12: arr[12] = val; break;
      case 13: arr[13] = val; break;
      case 14: arr[14] = val; break;
      case 15: arr[15] = val; break;
      case 16: arr[16] = val; break;
      case 17: arr[17] = val; break;
      case 18: arr[18] = val; break;
      case 19: arr[19] = val; break;
      case 20: arr[20] = val; break;
      case 21: arr[21] = val; break;
      case 22: arr[22] = val; break;
      case 23: arr[23] = val; break;
      case 24: arr[24] = val; break;
      case 25: arr[25] = val; break;
      case 26: arr[26] = val; break;
      case 27: arr[27] = val; break;
      case 28: arr[28] = val; break;
      case 29: arr[29] = val; break;
      case 30: arr[30] = val; break;
      case 31: arr[31] = val; break;
    }
  }
};
pytorch3d/csrc/knn/knn.cu (new file, 369 lines)
@@ -0,0 +1,369 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
#include <float.h>
#include <iostream>
#include <tuple>

#include "dispatch.cuh"
#include "mink.cuh"

template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t D,
    const size_t K) {
  // Stupid version: Make each thread handle one query point and loop over
  // all P2 target points. There are N * P1 input points to handle, so
  // do a trivial parallelization over threads.
  // Store both dists and indices for knn in global memory.
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int np = tid; np < N * P1; np += num_threads) {
    int n = np / P1;
    int p1 = np % P1;
    int offset = n * P1 * K + p1 * K;
    MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
    for (int p2 = 0; p2 < P2; ++p2) {
      // Find the distance between points1[n, p1] and points2[n, p2]
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
        scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
        scalar_t diff = coord1 - coord2;
        dist += diff * diff;
      }
      mink.add(dist, p2);
    }
  }
}

template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t K) {
  // Same idea as the previous version, but hoist D into a template argument
  // so we can cache the current point in a thread-local array. We still store
  // the current best K dists and indices in global memory, so this should work
  // for very large K and fairly large D.
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int num_threads = blockDim.x * gridDim.x;
  scalar_t cur_point[D];
  for (int np = tid; np < N * P1; np += num_threads) {
    int n = np / P1;
    int p1 = np % P1;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
    int offset = n * P1 * K + p1 * K;
    MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
    for (int p2 = 0; p2 < P2; ++p2) {
      // Find the distance between cur_point and points2[n, p2]
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
        dist += diff * diff;
      }
      mink.add(dist, p2);
    }
  }
}

// This is a shim functor to allow us to dispatch using DispatchKernel1D
template <typename scalar_t, int64_t D>
struct KNearestNeighborV1Functor {
  static void run(
      size_t blocks,
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const size_t N,
      const size_t P1,
      const size_t P2,
      const size_t K) {
    KNearestNeighborKernelV1<scalar_t, D>
        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
  }
};

template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const int64_t N,
    const int64_t P1,
    const int64_t P2) {
  // Same general implementation as V1, but also hoist K into a template arg.
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int num_threads = blockDim.x * gridDim.x;
  scalar_t cur_point[D];
  scalar_t min_dists[K];
  int min_idxs[K];
  for (int np = tid; np < N * P1; np += num_threads) {
    int n = np / P1;
    int p1 = np % P1;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
    MinK<scalar_t, int> mink(min_dists, min_idxs, K);
    for (int p2 = 0; p2 < P2; ++p2) {
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
        scalar_t diff = cur_point[d] - points2[offset];
        dist += diff * diff;
      }
      mink.add(dist, p2);
    }
    for (int k = 0; k < mink.size(); ++k) {
      idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
      dists[n * P1 * K + p1 * K + k] = min_dists[k];
    }
  }
}

// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV2Functor {
  static void run(
      size_t blocks,
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const int64_t N,
      const int64_t P1,
      const int64_t P2) {
    KNearestNeighborKernelV2<scalar_t, D, K>
        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
  }
};

template <typename scalar_t, int D, int K>
__global__ void KNearestNeighborKernelV3(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    scalar_t* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
    const size_t P2) {
  // Same idea as V2, but use register indexing for thread-local arrays.
  // Enabling sorting for this version leads to huge slowdowns; I suspect
  // that it forces min_dists into local memory rather than registers.
  // As a result this version is always unsorted.
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int num_threads = blockDim.x * gridDim.x;
  scalar_t cur_point[D];
  scalar_t min_dists[K];
  int min_idxs[K];
  for (int np = tid; np < N * P1; np += num_threads) {
    int n = np / P1;
    int p1 = np % P1;
    for (int d = 0; d < D; ++d) {
      cur_point[d] = points1[n * P1 * D + p1 * D + d];
    }
    RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
    for (int p2 = 0; p2 < P2; ++p2) {
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
        scalar_t diff = cur_point[d] - points2[offset];
        dist += diff * diff;
      }
      mink.add(dist, p2);
    }
    for (int k = 0; k < mink.size(); ++k) {
      idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
      dists[n * P1 * K + p1 * K + k] = min_dists[k];
    }
  }
}

// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV3Functor {
  static void run(
      size_t blocks,
      size_t threads,
      const scalar_t* __restrict__ points1,
      const scalar_t* __restrict__ points2,
      scalar_t* __restrict__ dists,
      int64_t* __restrict__ idxs,
      const size_t N,
      const size_t P1,
      const size_t P2) {
    KNearestNeighborKernelV3<scalar_t, D, K>
        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
  }
};

constexpr int V1_MIN_D = 1;
constexpr int V1_MAX_D = 32;

constexpr int V2_MIN_D = 1;
constexpr int V2_MAX_D = 8;
constexpr int V2_MIN_K = 1;
constexpr int V2_MAX_K = 32;

constexpr int V3_MIN_D = 1;
constexpr int V3_MAX_D = 8;
constexpr int V3_MIN_K = 1;
constexpr int V3_MAX_K = 4;

bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
  return min <= x && x <= max;
}

bool CheckVersion(int version, const int64_t D, const int64_t K) {
  if (version == 0) {
    return true;
  } else if (version == 1) {
    return InBounds(V1_MIN_D, D, V1_MAX_D);
  } else if (version == 2) {
    return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K);
  } else if (version == 3) {
    return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K);
  }
  return false;
}

int ChooseVersion(const int64_t D, const int64_t K) {
  for (int version = 3; version >= 1; version--) {
    if (CheckVersion(version, D, K)) {
      return version;
    }
  }
  return 0;
}

std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    int K,
    int version) {
  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
  const auto D = p2.size(2);
  const int64_t K_64 = K;

  AT_ASSERTM(p1.size(2) == D, "Point sets must have the same last dimension");
  auto long_dtype = p1.options().dtype(at::kLong);
  auto idxs = at::full({N, P1, K}, -1, long_dtype);
  auto dists = at::full({N, P1, K}, -1, p1.options());

  if (version < 0) {
    version = ChooseVersion(D, K);
  } else if (!CheckVersion(version, D, K)) {
    int new_version = ChooseVersion(D, K);
    std::cout << "WARNING: Requested KNN version " << version
              << " is not compatible with D = " << D << "; K = " << K
              << ". Falling back to version = " << new_version << std::endl;
    version = new_version;
  }

  // At this point we should have a valid version no matter what data the user
  // gave us. But we can check once more to be sure; however this time we
  // assert, since failing at this point means we have a bug in our version
  // selection or checking code.
  AT_ASSERTM(CheckVersion(version, D, K), "Invalid version");

  const size_t threads = 256;
  const size_t blocks = 256;
  if (version == 0) {
    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
      KNearestNeighborKernelV0<scalar_t><<<blocks, threads>>>(
          p1.data_ptr<scalar_t>(),
          p2.data_ptr<scalar_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2,
          D,
          K);
    }));
  } else if (version == 1) {
    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
      DispatchKernel1D<
          KNearestNeighborV1Functor,
          scalar_t,
          V1_MIN_D,
          V1_MAX_D>(
          D,
          blocks,
          threads,
          p1.data_ptr<scalar_t>(),
          p2.data_ptr<scalar_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2,
          K);
    }));
  } else if (version == 2) {
    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
      DispatchKernel2D<
          KNearestNeighborKernelV2Functor,
          scalar_t,
          V2_MIN_D,
          V2_MAX_D,
          V2_MIN_K,
          V2_MAX_K>(
          D,
          K_64,
          blocks,
          threads,
          p1.data_ptr<scalar_t>(),
          p2.data_ptr<scalar_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2);
    }));
  } else if (version == 3) {
    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
      DispatchKernel2D<
          KNearestNeighborKernelV3Functor,
          scalar_t,
          V3_MIN_D,
          V3_MAX_D,
          V3_MIN_K,
          V3_MAX_K>(
          D,
          K_64,
          blocks,
          threads,
          p1.data_ptr<scalar_t>(),
          p2.data_ptr<scalar_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
          P1,
          P2);
    }));
  }

  return std::make_tuple(idxs, dists);
}
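For reference, the version bounds and fallback heuristic above can be summarized in a small Python sketch. The names (BOUNDS, check_version, choose_version) are illustrative only; the real selection happens in the CUDA source above.

# Mirror of CheckVersion / ChooseVersion from knn.cu, for illustration only.
BOUNDS = {
    1: {"D": (1, 32), "K": (1, float("inf"))},  # V1: D templated, K kept in global memory
    2: {"D": (1, 8), "K": (1, 32)},             # V2: both D and K templated
    3: {"D": (1, 8), "K": (1, 4)},              # V3: register-indexed MinK
}

def check_version(version, D, K):
    if version == 0:
        return True  # V0 handles any D and K
    if version not in BOUNDS:
        return False
    (dmin, dmax), (kmin, kmax) = BOUNDS[version]["D"], BOUNDS[version]["K"]
    return dmin <= D <= dmax and kmin <= K <= kmax

def choose_version(D, K):
    # Prefer the most specialized kernel that supports this problem size.
    for version in (3, 2, 1):
        if check_version(version, D, K):
            return version
    return 0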
pytorch3d/csrc/knn/knn.h (new file, 54 lines)
@@ -0,0 +1,54 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#pragma once
#include <torch/extension.h>
#include <tuple>
#include "pytorch3d_cutils.h"

// Compute indices of K nearest neighbors in pointcloud p2 to points
// in pointcloud p1.
//
// Args:
//   p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
//       containing P1 points of dimension D.
//   p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
//       containing P2 points of dimension D.
//   K: int giving the number of nearest points to return.
//   sorted: bool telling whether to sort the K returned points by their
//       distance.
//   version: Integer telling which implementation to use.
//       TODO(jcjohns): Document this more, or maybe remove it before landing.
//
// Returns:
//   p1_neighbor_idx: LongTensor of shape (N, P1, K), where
//       p1_neighbor_idx[n, i, k] = j means that the kth nearest
//       neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
//   p1_neighbor_dists: FloatTensor of shape (N, P1, K) giving the squared
//       distances to the corresponding neighbors.

// CPU implementation.
std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);

// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    int K,
    int version);

// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
    const at::Tensor& p1,
    const at::Tensor& p2,
    int K,
    int version) {
  if (p1.type().is_cuda() || p2.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(p1);
    CHECK_CONTIGUOUS_CUDA(p2);
    return KNearestNeighborIdxCuda(p1, p2, K, version);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return KNearestNeighborIdxCpu(p1, p2, K);
}
pytorch3d/csrc/knn/knn_cpu.cpp (new file, 52 lines)
@@ -0,0 +1,52 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include <queue>
#include <tuple>

std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int D = p1.size(2);
  const int P2 = p2.size(1);

  auto long_opts = p1.options().dtype(torch::kInt64);
  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());

  auto p1_a = p1.accessor<float, 3>();
  auto p2_a = p2.accessor<float, 3>();
  auto idxs_a = idxs.accessor<int64_t, 3>();
  auto dists_a = dists.accessor<float, 3>();

  for (int n = 0; n < N; ++n) {
    for (int i1 = 0; i1 < P1; ++i1) {
      // Use a priority queue to store (distance, index) tuples.
      std::priority_queue<std::tuple<float, int>> q;
      for (int i2 = 0; i2 < P2; ++i2) {
        float dist = 0;
        for (int d = 0; d < D; ++d) {
          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
          dist += diff * diff;
        }
        int size = static_cast<int>(q.size());
        if (size < K || dist < std::get<0>(q.top())) {
          q.emplace(dist, i2);
          if (size >= K) {
            q.pop();
          }
        }
      }
      while (!q.empty()) {
        auto t = q.top();
        q.pop();
        const int k = q.size();
        dists_a[n][i1][k] = std::get<0>(t);
        idxs_a[n][i1][k] = std::get<1>(t);
      }
    }
  }
  return std::make_tuple(idxs, dists);
}
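The CPU path above keeps the K smallest squared distances in a bounded max-heap: std::priority_queue orders by the largest element, so q.top() is the current worst of the best K and gets popped whenever a closer point arrives. The same idea in Python, using heapq with negated distances since heapq is a min-heap (a sketch for illustration, not part of this commit; the function name is made up):

import heapq

def knn_bounded_heap(p1_row, p2_rows, K):
    """Return (indices, squared distances) of the K points in p2_rows closest to p1_row."""
    heap = []  # stores (-dist, index); the root is the worst of the kept K
    for j, q in enumerate(p2_rows):
        dist = sum((a - b) ** 2 for a, b in zip(p1_row, q))
        if len(heap) < K:
            heapq.heappush(heap, (-dist, j))
        elif dist < -heap[0][0]:
            heapq.heapreplace(heap, (-dist, j))
    pairs = sorted((-d, j) for d, j in heap)  # ascending by distance
    return [j for _, j in pairs], [d for d, _ in pairs]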
pytorch3d/csrc/mink.cuh (new file, 162 lines)
@@ -0,0 +1,162 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#pragma once
#define MINK_H

#include "index_utils.cuh"

// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
// in arrays passed to the constructor.
//
// The implementation is generic; it can be used for any key type that supports
// the < operator, and can be used with any value type.
//
// Example usage:
//
// float keys[K];
// int values[K];
// MinK<float, int> mink(keys, values, K);
// for (...) {
//   // Produce some key and value from somewhere
//   mink.add(key, value);
// }
// mink.sort();
//
// Now keys and values store the smallest K keys seen so far and the values
// associated to these keys:
//
// for (int k = 0; k < K; ++k) {
//   float key_k = keys[k];
//   int value_k = values[k];
// }
template<typename key_t, typename value_t>
class MinK {
 public:
  // Constructor.
  //
  // Arguments:
  //   keys: Array in which to store keys
  //   values: Array in which to store values
  //   K: How many values to keep track of
  __device__ MinK(key_t *keys, value_t *vals, int K) :
    keys(keys), vals(vals), K(K), _size(0) { }

  // Try to add a new key and associated value to the data structure. If the
  // key is one of the smallest K seen so far then it will be kept; otherwise
  // it will not be kept.
  //
  // This takes O(1) operations if the new key is not kept, or if the structure
  // currently contains fewer than K elements. Otherwise this takes O(K) time.
  //
  // Arguments:
  //   key: The key to add
  //   val: The value associated to the key
  __device__ __forceinline__ void add(const key_t &key, const value_t &val) {
    if (_size < K) {
      keys[_size] = key;
      vals[_size] = val;
      if (_size == 0 || key > max_key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (key < max_key) {
      keys[max_idx] = key;
      vals[max_idx] = val;
      max_key = key;
      for (int k = 0; k < K; ++k) {
        key_t cur_key = keys[k];
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  // Get the number of items currently stored in the structure.
  // This takes O(1) time.
  __device__ __forceinline__ int size() {
    return _size;
  }

  // Sort the items stored in the structure using bubble sort.
  // This takes O(K^2) time.
  __device__ __forceinline__ void sort() {
    for (int i = 0; i < _size - 1; ++i) {
      for (int j = 0; j < _size - i - 1; ++j) {
        if (keys[j + 1] < keys[j]) {
          key_t key = keys[j];
          value_t val = vals[j];
          keys[j] = keys[j + 1];
          vals[j] = vals[j + 1];
          keys[j + 1] = key;
          vals[j + 1] = val;
        }
      }
    }
  }

 private:
  key_t *keys;
  value_t *vals;
  int K;
  int _size;
  key_t max_key;
  int max_idx;
};

// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
// fast access.
//
// This has the same API as MinK, but doesn't support sorting.
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template<typename key_t, typename value_t, int K>
class RegisterMinK {
 public:
  __device__ RegisterMinK(key_t *keys, value_t *vals) :
    keys(keys), vals(vals), _size(0) {}

  __device__ __forceinline__ void add(const key_t &key, const value_t &val) {
    if (_size < K) {
      RegisterIndexUtils<key_t, K>::set(keys, _size, key);
      RegisterIndexUtils<value_t, K>::set(vals, _size, val);
      if (_size == 0 || key > max_key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (key < max_key) {
      RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
      RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
      max_key = key;
      for (int k = 0; k < K; ++k) {
        key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  __device__ __forceinline__ int size() {
    return _size;
  }

 private:
  key_t *keys;
  value_t *vals;
  int _size;
  key_t max_key;
  int max_idx;
};
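To make the O(1)/O(K) behaviour of MinK::add concrete, here is the same bookkeeping in plain Python (a sketch for illustration, not part of the commit): inserts are O(1) until the buffer holds K entries, after which a better key replaces the current maximum and the maximum is recomputed with a linear scan.

class MinK:
    """Keep the K smallest (key, value) pairs seen so far, mirroring mink.cuh."""
    def __init__(self, K):
        self.K = K
        self.keys, self.vals = [], []
        self.max_key, self.max_idx = None, -1

    def add(self, key, val):
        if len(self.keys) < self.K:
            self.keys.append(key)
            self.vals.append(val)
            if self.max_key is None or key > self.max_key:
                self.max_key, self.max_idx = key, len(self.keys) - 1
        elif key < self.max_key:
            self.keys[self.max_idx] = key
            self.vals[self.max_idx] = val
            # Rescan to find the new largest kept key: this is the O(K) step.
            self.max_idx = max(range(self.K), key=lambda i: self.keys[i])
            self.max_key = self.keys[self.max_idx]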
@@ -39,4 +39,4 @@ at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) {
#endif
  }
  return NearestNeighborIdxCpu(p1, p2);
};
}
pytorch3d/ops/knn.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
from pytorch3d import _C


def knn_points_idx(p1, p2, K, sorted=False, version=-1):
    """
    K-Nearest neighbors on point clouds.

    Args:
        p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
            containing P1 points of dimension D.
        p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
            containing P2 points of dimension D.
        K: Integer giving the number of nearest neighbors to return.
        sorted: Whether to sort the resulting points.
        version: Which KNN implementation to use in the backend. If version=-1,
            the correct implementation is selected based on the shapes of
            the inputs.

    Returns:
        idx: LongTensor of shape (N, P1, K) giving the indices of the
            K nearest neighbors from points in p1 to points in p2.
            Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
            nearest neighbors of p1[n, i] in p2[n]. If sorted=True, then
            p2[n, j] is the kth nearest neighbor to p1[n, i].
        dists: Tensor of shape (N, P1, K) giving the squared distances to
            the nearest neighbors.
    """
    idx, dists = _C.knn_points_idx(p1, p2, K, version)
    if sorted:
        dists, sort_idx = dists.sort(dim=2)
        idx = idx.gather(2, sort_idx)
    return idx, dists


@torch.no_grad()
def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
    """
    Naive PyTorch implementation of K-Nearest Neighbors.

    This is much less efficient than _C.knn_points_idx, but we include this
    naive implementation for testing and benchmarking.

    Args:
        p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
            containing P1 points of dimension D.
        p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
            containing P2 points of dimension D.
        K: Integer giving the number of nearest neighbors to return.
        sorted: Whether to sort the resulting points.

    Returns:
        idx: LongTensor of shape (N, P1, K) giving the indices of the
            K nearest neighbors from points in p1 to points in p2.
            Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
            nearest neighbors of p1[n, i] in p2[n]. If sorted=True, then
            p2[n, j] is the kth nearest neighbor to p1[n, i].
        dists: Tensor of shape (N, P1, K) giving the squared distances to the
            nearest neighbors.
    """
    N, P1, D = p1.shape
    _N, P2, _D = p2.shape
    assert N == _N and D == _D
    diffs = p1.view(N, P1, 1, D) - p2.view(N, 1, P2, D)
    dists2 = (diffs * diffs).sum(dim=3)
    out = dists2.topk(K, dim=2, largest=False, sorted=sorted)
    return out.indices, out.values
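One reason the naive version above is kept for testing and benchmarking only: it materializes the full (N, P1, P2, D) difference tensor. A quick back-of-the-envelope check with the largest CUDA benchmark size used below (illustrative arithmetic, not part of the commit):

# Memory needed just for the diffs tensor in _knn_points_idx_naive, float32.
N, P1, P2, D = 1, 16384, 16384, 3
elements = N * P1 * P2 * D          # 805,306,368 elements
bytes_needed = elements * 4         # 3,221,225,472 bytes
print(f"{bytes_needed / 2**30:.1f} GiB")  # -> 3.0 GiB for a single batch element

This is why bm_knn.py below only runs the naive path for P <= 4096.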
tests/bm_knn.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from itertools import product

import torch
from fvcore.common.benchmark import benchmark

from pytorch3d import _C
from pytorch3d.ops.knn import _knn_points_idx_naive


def bm_knn() -> None:
    """ Entry point for the benchmark """
    benchmark_knn_cpu()
    benchmark_knn_cuda_vs_naive()
    benchmark_knn_cuda_versions()


def benchmark_knn_cuda_versions() -> None:
    # Compare our different KNN implementations,
    # and also compare against our existing 1-NN
    Ns = [1, 2]
    Ps = [4096, 16384]
    Ds = [3]
    Ks = [1, 4, 16, 64]
    versions = [0, 1, 2, 3]
    knn_kwargs, nn_kwargs = [], []
    for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
        if version == 2 and K > 32:
            continue
        if version == 3 and K > 4:
            continue
        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
    for N, P, D in product(Ns, Ps, Ds):
        nn_kwargs.append({'N': N, 'D': D, 'P': P})
    benchmark(
        knn_cuda_with_init,
        'KNN_CUDA_VERSIONS',
        knn_kwargs,
        warmup_iters=1,
    )
    benchmark(
        nn_cuda_with_init,
        'NN_CUDA',
        nn_kwargs,
        warmup_iters=1,
    )


def benchmark_knn_cuda_vs_naive() -> None:
    # Compare against naive pytorch version of KNN
    Ns = [1, 2, 4]
    Ps = [1024, 4096, 16384, 65536]
    Ds = [3]
    Ks = [1, 2, 4, 8, 16]
    knn_kwargs, naive_kwargs = [], []
    for N, P, D, K in product(Ns, Ps, Ds, Ks):
        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
        if P <= 4096:
            naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
    benchmark(
        knn_python_cuda_with_init,
        'KNN_CUDA_PYTHON',
        naive_kwargs,
        warmup_iters=1,
    )
    benchmark(
        knn_cuda_with_init,
        'KNN_CUDA',
        knn_kwargs,
        warmup_iters=1,
    )


def benchmark_knn_cpu() -> None:
    Ns = [1, 2]
    Ps = [256, 512]
    Ds = [3]
    Ks = [1, 2, 4]
    knn_kwargs, nn_kwargs = [], []
    for N, P, D, K in product(Ns, Ps, Ds, Ks):
        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
    for N, P, D in product(Ns, Ps, Ds):
        nn_kwargs.append({'N': N, 'D': D, 'P': P})
    benchmark(
        knn_python_cpu_with_init,
        'KNN_CPU_PYTHON',
        knn_kwargs,
        warmup_iters=1,
    )
    benchmark(
        knn_cpu_with_init,
        'KNN_CPU_CPP',
        knn_kwargs,
        warmup_iters=1,
    )
    benchmark(
        nn_cpu_with_init,
        'NN_CPU_CPP',
        nn_kwargs,
        warmup_iters=1,
    )


def knn_cuda_with_init(N, D, P, K, v=-1):
    device = torch.device('cuda:0')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()

    def knn():
        _C.knn_points_idx(x, y, K, v)
        torch.cuda.synchronize()

    return knn


def knn_cpu_with_init(N, D, P, K):
    device = torch.device('cpu')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)

    def knn():
        _C.knn_points_idx(x, y, K, 0)

    return knn


def knn_python_cuda_with_init(N, D, P, K):
    device = torch.device('cuda')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()

    def knn():
        _knn_points_idx_naive(x, y, K)
        torch.cuda.synchronize()

    return knn


def knn_python_cpu_with_init(N, D, P, K):
    device = torch.device('cpu')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)

    def knn():
        _knn_points_idx_naive(x, y, K)

    return knn


def nn_cuda_with_init(N, D, P):
    device = torch.device('cuda')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)
    torch.cuda.synchronize()

    def knn():
        _C.nn_points_idx(x, y)
        torch.cuda.synchronize()

    return knn


def nn_cpu_with_init(N, D, P):
    device = torch.device('cpu')
    x = torch.randn(N, P, D, device=device)
    y = torch.randn(N, P, D, device=device)

    def knn():
        _C.nn_points_idx(x, y)

    return knn
tests/test_knn.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from itertools import product
import torch

from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx


class TestKNN(unittest.TestCase):
    def _check_knn_result(self, out1, out2, sorted):
        # When sorted=True, points should be sorted by distance and should
        # match between implementations. When sorted=False we only want to
        # check that we got the same set of indices, so we sort the indices by
        # index value.
        idx1, dist1 = out1
        idx2, dist2 = out2
        if not sorted:
            idx1 = idx1.sort(dim=2).values
            idx2 = idx2.sort(dim=2).values
            dist1 = dist1.sort(dim=2).values
            dist2 = dist2.sort(dim=2).values
        if not torch.all(idx1 == idx2):
            print(idx1)
            print(idx2)
        self.assertTrue(torch.all(idx1 == idx2))
        self.assertTrue(torch.allclose(dist1, dist2))

    def test_knn_vs_python_cpu(self):
        """ Test CPU output vs PyTorch implementation """
        device = torch.device('cpu')
        Ns = [1, 4]
        Ds = [2, 3]
        P1s = [1, 10, 101]
        P2s = [10, 101]
        Ks = [1, 3, 10]
        sorts = [True, False]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(x, y, K, sort)
            out2 = knn_points_idx(x, y, K, sort)
            self._check_knn_result(out1, out2, sort)

    def test_knn_vs_python_cuda(self):
        """ Test CUDA output vs PyTorch implementation """
        device = torch.device('cuda')
        Ns = [1, 4]
        Ds = [2, 3, 8]
        P1s = [1, 8, 64, 128, 1001]
        P2s = [32, 128, 513]
        Ks = [1, 3, 10]
        sorts = [True, False]
        versions = [0, 1, 2, 3]
        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
        for N, D, P1, P2, K, sort in product(*factors):
            x = torch.randn(N, P1, D, device=device)
            y = torch.randn(N, P2, D, device=device)
            out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
            for version in versions:
                if version == 3 and K > 4:
                    continue
                out2 = knn_points_idx(x, y, K, sort, version)
                self._check_knn_result(out1, out2, sort)
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from itertools import product
import torch

from pytorch3d import _C