Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-12-21 14:50:36 +08:00)
heterogeneous KNN
Summary: Interface and working implementation of ragged KNN. Benchmarks (which aren't ragged) haven't slowed. New benchmark shows that ragged is faster than non-ragged of the same shape.

Reviewed By: jcjohnson

Differential Revision: D20696507

fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
Committed by Facebook GitHub Bot
Parent: 29b9c44c0a
Commit: 01b5f7b228
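The change makes the KNN operator accept ragged batches: each cloud in the batch is still stored in a padded (N, P, D) tensor, but two new LongTensors, lengths1 and lengths2, record how many rows of each cloud are real. The host-side sketch below is not part of the commit (shapes and values are invented) and only illustrates what such inputs look like; the commented call refers to the KNearestNeighborIdx dispatcher whose new signature appears later in this diff.

#include <torch/torch.h>

int main() {
  const int64_t N = 2, P1 = 1000, P2 = 800, D = 3;
  // Padded batches: every cloud is stored with P1 (or P2) rows.
  at::Tensor p1 = torch::randn({N, P1, D});
  at::Tensor p2 = torch::randn({N, P2, D});
  // Per-cloud valid lengths: cloud 0 is full, cloud 1 is ragged.
  at::Tensor lengths1 = torch::tensor({1000, 600}, torch::kInt64);
  at::Tensor lengths2 = torch::tensor({800, 350}, torch::kInt64);
  // The dispatcher added later in this diff would be called roughly as:
  //   auto out = KNearestNeighborIdx(p1, p2, lengths1, lengths2, /*K=*/8, /*version=*/-1);
  return 0;
}

The kernels then read lengths1[n] and lengths2[n] instead of looping to the padded sizes P1 and P2.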
@@ -8,10 +8,20 @@
 #include "dispatch.cuh"
 #include "mink.cuh"
 
+// A chunk of work is blocksize-many points of P1.
+// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
+// call (1+(P1-1)/blocksize) chunks_per_cloud
+// These chunks are divided among the gridSize-many blocks.
+// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
+// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
+// blocksize*(i%chunks_per_cloud).
+
 template <typename scalar_t>
 __global__ void KNearestNeighborKernelV0(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const size_t N,
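A small worked example of the chunking arithmetic described in the new comment block, using invented sizes (N = 2 clouds, P1 = 1000 padded points, 256 threads per block); plain host code for illustration only.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t N = 2, P1 = 1000, block_size = 256;
  const int64_t chunks_per_cloud = 1 + (P1 - 1) / block_size; // = 4
  const int64_t chunks_to_do = N * chunks_per_cloud;          // = 8
  for (int64_t chunk = 0; chunk < chunks_to_do; ++chunk) {
    const int64_t n = chunk / chunks_per_cloud;               // which cloud
    const int64_t start = block_size * (chunk % chunks_per_cloud);
    std::printf(
        "chunk %lld -> cloud %lld, points [%lld, %lld)\n",
        static_cast<long long>(chunk),
        static_cast<long long>(n),
        static_cast<long long>(start),
        static_cast<long long>(start + block_size));
  }
  return 0;
}

Block b of the grid takes chunks b, b + gridSize, b + 2*gridSize, and so on; within a chunk, thread t handles point start + t and skips the chunk when start + t >= lengths1[n].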
@@ -19,18 +29,19 @@ __global__ void KNearestNeighborKernelV0(
     const size_t P2,
     const size_t D,
     const size_t K) {
-  // Stupid version: Make each thread handle one query point and loop over
-  // all P2 target points. There are N * P1 input points to handle, so
-  // do a trivial parallelization over threads.
   // Store both dists and indices for knn in global memory.
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     int offset = n * P1 * K + p1 * K;
+    int64_t length2 = lengths2[n];
     MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       // Find the distance between points1[n, p1] and points[n, p2]
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
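MinK comes from mink.cuh and its internals are not part of this diff; the kernel hands it the K global-memory slots at dists + offset and idxs + offset and feeds it one candidate per p2. The struct below is an assumed, simplified host-side stand-in for that behaviour (the method name add and all internals are guesses for illustration, not the real mink.cuh code).

// Simplified illustration of a running "K smallest" container.
template <typename T, typename Index>
struct MinKSketch {
  T* dists;     // K output slots, e.g. the row dists + offset
  Index* idxs;  // matching index slots, e.g. idxs + offset
  int K;
  int size = 0;
  MinKSketch(T* d, Index* i, int k) : dists(d), idxs(i), K(k) {}

  void add(T d, Index i) {
    if (size < K) {  // still filling the K slots
      dists[size] = d;
      idxs[size] = i;
      ++size;
      return;
    }
    int worst = 0;   // otherwise replace the current worst entry
    for (int j = 1; j < K; ++j) {
      if (dists[j] > dists[worst]) worst = j;
    }
    if (d < dists[worst]) {
      dists[worst] = d;
      idxs[worst] = i;
    }
  }
};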
@@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
 __global__ void KNearestNeighborKernelV1(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const size_t N,
@@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
   // so we can cache the current point in a thread-local array. We still store
   // the current best K dists and indices in global memory, so this should work
   // for very large K and fairly large D.
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
   scalar_t cur_point[D];
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     for (int d = 0; d < D; ++d) {
       cur_point[d] = points1[n * P1 * D + p1 * D + d];
     }
     int offset = n * P1 * K + p1 * K;
+    int64_t length2 = lengths2[n];
     MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       // Find the distance between cur_point and points[n, p2]
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
@@ -89,14 +106,16 @@ struct KNearestNeighborV1Functor {
       size_t threads,
       const scalar_t* __restrict__ points1,
       const scalar_t* __restrict__ points2,
+      const int64_t* __restrict__ lengths1,
+      const int64_t* __restrict__ lengths2,
       scalar_t* __restrict__ dists,
       int64_t* __restrict__ idxs,
       const size_t N,
       const size_t P1,
       const size_t P2,
       const size_t K) {
-    KNearestNeighborKernelV1<scalar_t, D>
-        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
+    KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
   }
 };
 
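The functor exists so the kernel launch can be templated on D even though D is only known at runtime: a dispatch helper in dispatch.cuh (its real name, signature, and how it invokes the functor are not visible in this diff) instantiates the functor for each supported D and calls the one matching the runtime value. A generic sketch of that pattern, with every name assumed:

#include <cstdint>
#include <utility>

// Assumed sketch of runtime-to-compile-time dispatch over D.
template <
    template <typename, int64_t> class Functor,
    typename scalar_t,
    int64_t MIN_D,
    int64_t MAX_D,
    typename... Args>
void DispatchOnD(int64_t d, Args&&... args) {
  if constexpr (MIN_D <= MAX_D) {
    if (d == MIN_D) {
      // Found the specialization matching the runtime D; invoke it.
      Functor<scalar_t, MIN_D>()(std::forward<Args>(args)...);
    } else {
      // Otherwise keep walking up the compile-time range of supported D.
      DispatchOnD<Functor, scalar_t, MIN_D + 1, MAX_D>(
          d, std::forward<Args>(args)...);
    }
  }
}

// Hypothetical use (argument list assumed):
//   DispatchOnD<KNearestNeighborV1Functor, float, 1, 8>(
//       D, blocks, threads, points1, points2, lengths1, lengths2, ...);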
@@ -104,25 +123,31 @@ template <typename scalar_t, int64_t D, int64_t K>
 __global__ void KNearestNeighborKernelV2(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const int64_t N,
     const int64_t P1,
     const int64_t P2) {
   // Same general implementation as V2, but also hoist K into a template arg.
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
   scalar_t cur_point[D];
   scalar_t min_dists[K];
   int min_idxs[K];
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     for (int d = 0; d < D; ++d) {
       cur_point[d] = points1[n * P1 * D + p1 * D + d];
     }
+    int64_t length2 = lengths2[n];
     MinK<scalar_t, int> mink(min_dists, min_idxs, K);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
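Hoisting K into a template argument gives min_dists and min_idxs compile-time sizes, so a concrete specialization has fixed-size thread-local arrays the compiler can keep in registers and unroll over; which (D, K) specialization to launch is presumably picked at runtime by the dispatch machinery. An illustrative explicit instantiation, with assumed values D = 3 and K = 8 (not part of the commit), matching the new signature above:

template __global__ void KNearestNeighborKernelV2<float, 3, 8>(
    const float* __restrict__ points1,
    const float* __restrict__ points2,
    const int64_t* __restrict__ lengths1,
    const int64_t* __restrict__ lengths2,
    float* __restrict__ dists,
    int64_t* __restrict__ idxs,
    const int64_t N,
    const int64_t P1,
    const int64_t P2);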
@@ -146,13 +171,15 @@ struct KNearestNeighborKernelV2Functor {
       size_t threads,
       const scalar_t* __restrict__ points1,
       const scalar_t* __restrict__ points2,
+      const int64_t* __restrict__ lengths1,
+      const int64_t* __restrict__ lengths2,
       scalar_t* __restrict__ dists,
       int64_t* __restrict__ idxs,
       const int64_t N,
       const int64_t P1,
       const int64_t P2) {
-    KNearestNeighborKernelV2<scalar_t, D, K>
-        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
+    KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
   }
 };
 
@@ -160,6 +187,8 @@ template <typename scalar_t, int D, int K>
 __global__ void KNearestNeighborKernelV3(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
+    const int64_t* __restrict__ lengths1,
+    const int64_t* __restrict__ lengths2,
     scalar_t* __restrict__ dists,
     int64_t* __restrict__ idxs,
     const size_t N,
@@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
   // Enabling sorting for this version leads to huge slowdowns; I suspect
   // that it forces min_dists into local memory rather than registers.
   // As a result this version is always unsorted.
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
   scalar_t cur_point[D];
   scalar_t min_dists[K];
   int min_idxs[K];
-  for (int np = tid; np < N * P1; np += num_threads) {
-    int n = np / P1;
-    int p1 = np % P1;
+  const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+  const int64_t chunks_to_do = N * chunks_per_cloud;
+  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+    const int64_t n = chunk / chunks_per_cloud;
+    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+    int64_t p1 = start_point + threadIdx.x;
+    if (p1 >= lengths1[n])
+      continue;
     for (int d = 0; d < D; ++d) {
       cur_point[d] = points1[n * P1 * D + p1 * D + d];
     }
+    int64_t length2 = lengths2[n];
     RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
-    for (int p2 = 0; p2 < P2; ++p2) {
+    for (int p2 = 0; p2 < length2; ++p2) {
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
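RegisterMinK, also from mink.cuh and not shown here, plays the same role as MinK but is written so that the thread-local arrays are only indexed with compile-time-known indices, which is what lets them stay in registers (and why sorting is disabled, per the comment above). The following is an assumed, simplified CUDA sketch of that idea, not the real implementation:

template <typename T, typename Index, int K>
struct RegisterMinKSketch {
  T* dists;      // K best distances, intended to stay in registers
  Index* idxs;   // matching indices
  int size = 0;  // how many slots are filled so far
  __device__ RegisterMinKSketch(T* d, Index* i) : dists(d), idxs(i) {}

  __device__ void add(T d, Index i) {
    // Pick the slot to overwrite: the next free slot while filling,
    // otherwise the slot holding the current worst distance.
    int target = -1;
    T worst = d;
#pragma unroll
    for (int j = 0; j < K; ++j) {
      if (j < size && dists[j] > worst) {
        worst = dists[j];
        target = j;
      }
    }
    if (size < K) {
      target = size++;
    }
    if (target >= 0) {
#pragma unroll
      for (int j = 0; j < K; ++j) {
        if (j == target) {  // compile-time j keeps the arrays in registers
          dists[j] = d;
          idxs[j] = i;
        }
      }
    }
  }
};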
@@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
       size_t threads,
       const scalar_t* __restrict__ points1,
       const scalar_t* __restrict__ points2,
+      const int64_t* __restrict__ lengths1,
+      const int64_t* __restrict__ lengths2,
       scalar_t* __restrict__ dists,
       int64_t* __restrict__ idxs,
       const size_t N,
       const size_t P1,
       const size_t P2) {
-    KNearestNeighborKernelV3<scalar_t, D, K>
-        <<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
+    KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
   }
 };
 
@@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p1,
     const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
     int K,
     int version) {
   const auto N = p1.size(0);
@@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
 
   AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
   auto long_dtype = p1.options().dtype(at::kLong);
-  auto idxs = at::full({N, P1, K}, -1, long_dtype);
-  auto dists = at::full({N, P1, K}, -1, p1.options());
+  auto idxs = at::zeros({N, P1, K}, long_dtype);
+  auto dists = at::zeros({N, P1, K}, p1.options());
 
   if (version < 0) {
     version = ChooseVersion(D, K);
@@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
         <<<blocks, threads>>>(
             p1.data_ptr<scalar_t>(),
             p2.data_ptr<scalar_t>(),
+            lengths1.data_ptr<int64_t>(),
+            lengths2.data_ptr<int64_t>(),
            dists.data_ptr<scalar_t>(),
            idxs.data_ptr<int64_t>(),
            N,
@@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
         threads,
         p1.data_ptr<scalar_t>(),
         p2.data_ptr<scalar_t>(),
+        lengths1.data_ptr<int64_t>(),
+        lengths2.data_ptr<int64_t>(),
         dists.data_ptr<scalar_t>(),
         idxs.data_ptr<int64_t>(),
         N,
@@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
         threads,
         p1.data_ptr<scalar_t>(),
         p2.data_ptr<scalar_t>(),
+        lengths1.data_ptr<int64_t>(),
+        lengths2.data_ptr<int64_t>(),
         dists.data_ptr<scalar_t>(),
         idxs.data_ptr<int64_t>(),
         N,
@@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
         threads,
         p1.data_ptr<scalar_t>(),
         p2.data_ptr<scalar_t>(),
+        lengths1.data_ptr<int64_t>(),
+        lengths2.data_ptr<int64_t>(),
         dists.data_ptr<scalar_t>(),
         idxs.data_ptr<int64_t>(),
         N,

@@ -13,25 +13,38 @@
 // containing P1 points of dimension D.
 // p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
 // containing P2 points of dimension D.
+// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
+// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
 // K: int giving the number of nearest points to return.
 // sorted: bool telling whether to sort the K returned points by their
 //         distance.
 // version: Integer telling which implementation to use.
 // TODO(jcjohns): Document this more, or maybe remove it before landing.
 //
 // Returns:
 // p1_neighbor_idx: LongTensor of shape (N, P1, K), where
-//    p1_neighbor_idx[n, i, k] = j means that the kth nearest
-//    neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+//    p1_neighbor_idx[n, i, k] = j means that the kth nearest
+//    neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+//    It is padded with zeros so that it can be used easily in a later
+//    gather() operation.
+//
 // p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
 //    distance from each point p1[n, p, :] to its K neighbors
 //    p2[n, p1_neighbor_idx[n, p, k], :].
 
 // CPU implementation.
-std::tuple<at::Tensor, at::Tensor>
-KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    int K);
 
 // CUDA implementation
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p1,
     const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
     int K,
     int version);
 
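The header text above notes that the returned index tensor is zero-padded so it can be used directly in a later gather(). A libtorch illustration of that use (the helper name and shapes are mine, not from the commit): padded entries point at index 0, which is always in range, so no masking is needed before the gather itself.

#include <torch/torch.h>

// Gather the D coordinates of each of the K neighbors.
// p2: (N, P2, D), idx: (N, P1, K) zero-padded; returns (N, P1, K, D).
at::Tensor gather_neighbors(const at::Tensor& p2, const at::Tensor& idx) {
  const int64_t D = p2.size(2);
  // (N, P1, K) -> (N, P1*K, 1) -> (N, P1*K, D) so gather selects whole rows.
  auto flat = idx.reshape({idx.size(0), -1, 1}).expand({-1, -1, D});
  auto rows = p2.gather(1, flat);  // one row of p2 per neighbor index
  return rows.reshape({idx.size(0), idx.size(1), idx.size(2), D});
}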
@@ -39,16 +52,18 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     const at::Tensor& p1,
     const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
     int K,
     int version) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CONTIGUOUS_CUDA(p1);
     CHECK_CONTIGUOUS_CUDA(p2);
-    return KNearestNeighborIdxCuda(p1, p2, K, version);
+    return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
-  return KNearestNeighborIdxCpu(p1, p2, K);
+  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
 }
 
@@ -4,27 +4,35 @@
 #include <queue>
 #include <tuple>
 
-std::tuple<at::Tensor, at::Tensor>
-KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    int K) {
   const int N = p1.size(0);
   const int P1 = p1.size(1);
   const int D = p1.size(2);
   const int P2 = p2.size(1);
 
   auto long_opts = p1.options().dtype(torch::kInt64);
-  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
+  torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
   torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
 
   auto p1_a = p1.accessor<float, 3>();
   auto p2_a = p2.accessor<float, 3>();
+  auto lengths1_a = lengths1.accessor<int64_t, 1>();
+  auto lengths2_a = lengths2.accessor<int64_t, 1>();
   auto idxs_a = idxs.accessor<int64_t, 3>();
   auto dists_a = dists.accessor<float, 3>();
 
   for (int n = 0; n < N; ++n) {
-    for (int i1 = 0; i1 < P1; ++i1) {
+    const int64_t length1 = lengths1_a[n];
+    const int64_t length2 = lengths2_a[n];
+    for (int64_t i1 = 0; i1 < length1; ++i1) {
       // Use a priority queue to store (distance, index) tuples.
       std::priority_queue<std::tuple<float, int>> q;
-      for (int i2 = 0; i2 < P2; ++i2) {
+      for (int64_t i2 = 0; i2 < length2; ++i2) {
        float dist = 0;
        for (int d = 0; d < D; ++d) {
          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
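The diff is cut off inside the innermost loop of the CPU implementation. Below is a sketch of how the priority-queue selection typically finishes, continuing the code above (an assumed reconstruction, not necessarily the file's exact remainder): accumulate the squared distance, keep the K smallest (distance, index) pairs in the max-heap, then drain the heap into idxs_a and dists_a from worst to best.

          dist += diff * diff;
        }
        const int size = static_cast<int>(q.size());
        if (size < K || dist < std::get<0>(q.top())) {
          q.emplace(dist, static_cast<int>(i2));  // max-heap of (dist, index)
          if (size >= K) {
            q.pop();  // evict the current worst candidate
          }
        }
      }
      // Drain the heap: entries come out worst-first, so fill slots backwards.
      while (!q.empty()) {
        const float neighbor_dist = std::get<0>(q.top());
        const int neighbor_idx = std::get<1>(q.top());
        q.pop();
        const int k = static_cast<int>(q.size());
        dists_a[n][i1][k] = neighbor_dist;
        idxs_a[n][i1][k] = neighbor_idx;
      }
    }
  }
  return std::make_tuple(idxs, dists);
}

Clouds with fewer than K valid neighbors leave the remaining slots at the zero-initialized padding values, matching the behaviour documented in the header.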