mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
heterogenous KNN
Summary: Interface and working implementation of ragged KNN. Benchmarks (which aren't ragged) haven't slowed. New benchmark shows that ragged is faster than non-ragged of the same shape. Reviewed By: jcjohnson Differential Revision: D20696507 fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
This commit is contained in:
parent
29b9c44c0a
commit
01b5f7b228
@ -8,10 +8,20 @@
|
|||||||
#include "dispatch.cuh"
|
#include "dispatch.cuh"
|
||||||
#include "mink.cuh"
|
#include "mink.cuh"
|
||||||
|
|
||||||
|
// A chunk of work is blocksize-many points of P1.
|
||||||
|
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
|
||||||
|
// call (1+(P1-1)/blocksize) chunks_per_cloud
|
||||||
|
// These chunks are divided among the gridSize-many blocks.
|
||||||
|
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
|
||||||
|
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
|
||||||
|
// blocksize*(i%chunks_per_cloud).
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void KNearestNeighborKernelV0(
|
__global__ void KNearestNeighborKernelV0(
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
|
const int64_t* __restrict__ lengths1,
|
||||||
|
const int64_t* __restrict__ lengths2,
|
||||||
scalar_t* __restrict__ dists,
|
scalar_t* __restrict__ dists,
|
||||||
int64_t* __restrict__ idxs,
|
int64_t* __restrict__ idxs,
|
||||||
const size_t N,
|
const size_t N,
|
||||||
@ -19,18 +29,19 @@ __global__ void KNearestNeighborKernelV0(
|
|||||||
const size_t P2,
|
const size_t P2,
|
||||||
const size_t D,
|
const size_t D,
|
||||||
const size_t K) {
|
const size_t K) {
|
||||||
// Stupid version: Make each thread handle one query point and loop over
|
|
||||||
// all P2 target points. There are N * P1 input points to handle, so
|
|
||||||
// do a trivial parallelization over threads.
|
|
||||||
// Store both dists and indices for knn in global memory.
|
// Store both dists and indices for knn in global memory.
|
||||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||||
const int num_threads = blockDim.x * gridDim.x;
|
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||||
for (int np = tid; np < N * P1; np += num_threads) {
|
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||||
int n = np / P1;
|
const int64_t n = chunk / chunks_per_cloud;
|
||||||
int p1 = np % P1;
|
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||||
|
int64_t p1 = start_point + threadIdx.x;
|
||||||
|
if (p1 >= lengths1[n])
|
||||||
|
continue;
|
||||||
int offset = n * P1 * K + p1 * K;
|
int offset = n * P1 * K + p1 * K;
|
||||||
|
int64_t length2 = lengths2[n];
|
||||||
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
|
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
|
||||||
for (int p2 = 0; p2 < P2; ++p2) {
|
for (int p2 = 0; p2 < length2; ++p2) {
|
||||||
// Find the distance between points1[n, p1] and points[n, p2]
|
// Find the distance between points1[n, p1] and points[n, p2]
|
||||||
scalar_t dist = 0;
|
scalar_t dist = 0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
|
|||||||
__global__ void KNearestNeighborKernelV1(
|
__global__ void KNearestNeighborKernelV1(
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
|
const int64_t* __restrict__ lengths1,
|
||||||
|
const int64_t* __restrict__ lengths2,
|
||||||
scalar_t* __restrict__ dists,
|
scalar_t* __restrict__ dists,
|
||||||
int64_t* __restrict__ idxs,
|
int64_t* __restrict__ idxs,
|
||||||
const size_t N,
|
const size_t N,
|
||||||
@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
|
|||||||
// so we can cache the current point in a thread-local array. We still store
|
// so we can cache the current point in a thread-local array. We still store
|
||||||
// the current best K dists and indices in global memory, so this should work
|
// the current best K dists and indices in global memory, so this should work
|
||||||
// for very large K and fairly large D.
|
// for very large K and fairly large D.
|
||||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
const int num_threads = blockDim.x * gridDim.x;
|
|
||||||
scalar_t cur_point[D];
|
scalar_t cur_point[D];
|
||||||
for (int np = tid; np < N * P1; np += num_threads) {
|
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||||
int n = np / P1;
|
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||||
int p1 = np % P1;
|
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||||
|
const int64_t n = chunk / chunks_per_cloud;
|
||||||
|
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||||
|
int64_t p1 = start_point + threadIdx.x;
|
||||||
|
if (p1 >= lengths1[n])
|
||||||
|
continue;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
||||||
}
|
}
|
||||||
int offset = n * P1 * K + p1 * K;
|
int offset = n * P1 * K + p1 * K;
|
||||||
|
int64_t length2 = lengths2[n];
|
||||||
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
|
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
|
||||||
for (int p2 = 0; p2 < P2; ++p2) {
|
for (int p2 = 0; p2 < length2; ++p2) {
|
||||||
// Find the distance between cur_point and points[n, p2]
|
// Find the distance between cur_point and points[n, p2]
|
||||||
scalar_t dist = 0;
|
scalar_t dist = 0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
@ -89,14 +106,16 @@ struct KNearestNeighborV1Functor {
|
|||||||
size_t threads,
|
size_t threads,
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
|
const int64_t* __restrict__ lengths1,
|
||||||
|
const int64_t* __restrict__ lengths2,
|
||||||
scalar_t* __restrict__ dists,
|
scalar_t* __restrict__ dists,
|
||||||
int64_t* __restrict__ idxs,
|
int64_t* __restrict__ idxs,
|
||||||
const size_t N,
|
const size_t N,
|
||||||
const size_t P1,
|
const size_t P1,
|
||||||
const size_t P2,
|
const size_t P2,
|
||||||
const size_t K) {
|
const size_t K) {
|
||||||
KNearestNeighborKernelV1<scalar_t, D>
|
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
|
||||||
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
|
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -104,25 +123,31 @@ template <typename scalar_t, int64_t D, int64_t K>
|
|||||||
__global__ void KNearestNeighborKernelV2(
|
__global__ void KNearestNeighborKernelV2(
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
|
const int64_t* __restrict__ lengths1,
|
||||||
|
const int64_t* __restrict__ lengths2,
|
||||||
scalar_t* __restrict__ dists,
|
scalar_t* __restrict__ dists,
|
||||||
int64_t* __restrict__ idxs,
|
int64_t* __restrict__ idxs,
|
||||||
const int64_t N,
|
const int64_t N,
|
||||||
const int64_t P1,
|
const int64_t P1,
|
||||||
const int64_t P2) {
|
const int64_t P2) {
|
||||||
// Same general implementation as V2, but also hoist K into a template arg.
|
// Same general implementation as V2, but also hoist K into a template arg.
|
||||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
const int num_threads = blockDim.x * gridDim.x;
|
|
||||||
scalar_t cur_point[D];
|
scalar_t cur_point[D];
|
||||||
scalar_t min_dists[K];
|
scalar_t min_dists[K];
|
||||||
int min_idxs[K];
|
int min_idxs[K];
|
||||||
for (int np = tid; np < N * P1; np += num_threads) {
|
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||||
int n = np / P1;
|
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||||
int p1 = np % P1;
|
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||||
|
const int64_t n = chunk / chunks_per_cloud;
|
||||||
|
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||||
|
int64_t p1 = start_point + threadIdx.x;
|
||||||
|
if (p1 >= lengths1[n])
|
||||||
|
continue;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
||||||
}
|
}
|
||||||
|
int64_t length2 = lengths2[n];
|
||||||
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
|
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
|
||||||
for (int p2 = 0; p2 < P2; ++p2) {
|
for (int p2 = 0; p2 < length2; ++p2) {
|
||||||
scalar_t dist = 0;
|
scalar_t dist = 0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
int offset = n * P2 * D + p2 * D + d;
|
int offset = n * P2 * D + p2 * D + d;
|
||||||
@ -146,13 +171,15 @@ struct KNearestNeighborKernelV2Functor {
|
|||||||
size_t threads,
|
size_t threads,
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
|
const int64_t* __restrict__ lengths1,
|
||||||
|
const int64_t* __restrict__ lengths2,
|
||||||
scalar_t* __restrict__ dists,
|
scalar_t* __restrict__ dists,
|
||||||
int64_t* __restrict__ idxs,
|
int64_t* __restrict__ idxs,
|
||||||
const int64_t N,
|
const int64_t N,
|
||||||
const int64_t P1,
|
const int64_t P1,
|
||||||
const int64_t P2) {
|
const int64_t P2) {
|
||||||
KNearestNeighborKernelV2<scalar_t, D, K>
|
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
|
||||||
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
|
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -160,6 +187,8 @@ template <typename scalar_t, int D, int K>
|
|||||||
__global__ void KNearestNeighborKernelV3(
|
__global__ void KNearestNeighborKernelV3(
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
|
const int64_t* __restrict__ lengths1,
|
||||||
|
const int64_t* __restrict__ lengths2,
|
||||||
scalar_t* __restrict__ dists,
|
scalar_t* __restrict__ dists,
|
||||||
int64_t* __restrict__ idxs,
|
int64_t* __restrict__ idxs,
|
||||||
const size_t N,
|
const size_t N,
|
||||||
@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
|
|||||||
// Enabling sorting for this version leads to huge slowdowns; I suspect
|
// Enabling sorting for this version leads to huge slowdowns; I suspect
|
||||||
// that it forces min_dists into local memory rather than registers.
|
// that it forces min_dists into local memory rather than registers.
|
||||||
// As a result this version is always unsorted.
|
// As a result this version is always unsorted.
|
||||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
const int num_threads = blockDim.x * gridDim.x;
|
|
||||||
scalar_t cur_point[D];
|
scalar_t cur_point[D];
|
||||||
scalar_t min_dists[K];
|
scalar_t min_dists[K];
|
||||||
int min_idxs[K];
|
int min_idxs[K];
|
||||||
for (int np = tid; np < N * P1; np += num_threads) {
|
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||||
int n = np / P1;
|
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||||
int p1 = np % P1;
|
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||||
|
const int64_t n = chunk / chunks_per_cloud;
|
||||||
|
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||||
|
int64_t p1 = start_point + threadIdx.x;
|
||||||
|
if (p1 >= lengths1[n])
|
||||||
|
continue;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
||||||
}
|
}
|
||||||
|
int64_t length2 = lengths2[n];
|
||||||
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
|
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
|
||||||
for (int p2 = 0; p2 < P2; ++p2) {
|
for (int p2 = 0; p2 < length2; ++p2) {
|
||||||
scalar_t dist = 0;
|
scalar_t dist = 0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
int offset = n * P2 * D + p2 * D + d;
|
int offset = n * P2 * D + p2 * D + d;
|
||||||
@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
|
|||||||
size_t threads,
|
size_t threads,
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
|
const int64_t* __restrict__ lengths1,
|
||||||
|
const int64_t* __restrict__ lengths2,
|
||||||
scalar_t* __restrict__ dists,
|
scalar_t* __restrict__ dists,
|
||||||
int64_t* __restrict__ idxs,
|
int64_t* __restrict__ idxs,
|
||||||
const size_t N,
|
const size_t N,
|
||||||
const size_t P1,
|
const size_t P1,
|
||||||
const size_t P2) {
|
const size_t P2) {
|
||||||
KNearestNeighborKernelV3<scalar_t, D, K>
|
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
|
||||||
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
|
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
|
|||||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||||
const at::Tensor& p1,
|
const at::Tensor& p1,
|
||||||
const at::Tensor& p2,
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
int K,
|
int K,
|
||||||
int version) {
|
int version) {
|
||||||
const auto N = p1.size(0);
|
const auto N = p1.size(0);
|
||||||
@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
|
|
||||||
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
|
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
|
||||||
auto long_dtype = p1.options().dtype(at::kLong);
|
auto long_dtype = p1.options().dtype(at::kLong);
|
||||||
auto idxs = at::full({N, P1, K}, -1, long_dtype);
|
auto idxs = at::zeros({N, P1, K}, long_dtype);
|
||||||
auto dists = at::full({N, P1, K}, -1, p1.options());
|
auto dists = at::zeros({N, P1, K}, p1.options());
|
||||||
|
|
||||||
if (version < 0) {
|
if (version < 0) {
|
||||||
version = ChooseVersion(D, K);
|
version = ChooseVersion(D, K);
|
||||||
@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
<<<blocks, threads>>>(
|
<<<blocks, threads>>>(
|
||||||
p1.data_ptr<scalar_t>(),
|
p1.data_ptr<scalar_t>(),
|
||||||
p2.data_ptr<scalar_t>(),
|
p2.data_ptr<scalar_t>(),
|
||||||
|
lengths1.data_ptr<int64_t>(),
|
||||||
|
lengths2.data_ptr<int64_t>(),
|
||||||
dists.data_ptr<scalar_t>(),
|
dists.data_ptr<scalar_t>(),
|
||||||
idxs.data_ptr<int64_t>(),
|
idxs.data_ptr<int64_t>(),
|
||||||
N,
|
N,
|
||||||
@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
threads,
|
threads,
|
||||||
p1.data_ptr<scalar_t>(),
|
p1.data_ptr<scalar_t>(),
|
||||||
p2.data_ptr<scalar_t>(),
|
p2.data_ptr<scalar_t>(),
|
||||||
|
lengths1.data_ptr<int64_t>(),
|
||||||
|
lengths2.data_ptr<int64_t>(),
|
||||||
dists.data_ptr<scalar_t>(),
|
dists.data_ptr<scalar_t>(),
|
||||||
idxs.data_ptr<int64_t>(),
|
idxs.data_ptr<int64_t>(),
|
||||||
N,
|
N,
|
||||||
@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
threads,
|
threads,
|
||||||
p1.data_ptr<scalar_t>(),
|
p1.data_ptr<scalar_t>(),
|
||||||
p2.data_ptr<scalar_t>(),
|
p2.data_ptr<scalar_t>(),
|
||||||
|
lengths1.data_ptr<int64_t>(),
|
||||||
|
lengths2.data_ptr<int64_t>(),
|
||||||
dists.data_ptr<scalar_t>(),
|
dists.data_ptr<scalar_t>(),
|
||||||
idxs.data_ptr<int64_t>(),
|
idxs.data_ptr<int64_t>(),
|
||||||
N,
|
N,
|
||||||
@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
threads,
|
threads,
|
||||||
p1.data_ptr<scalar_t>(),
|
p1.data_ptr<scalar_t>(),
|
||||||
p2.data_ptr<scalar_t>(),
|
p2.data_ptr<scalar_t>(),
|
||||||
|
lengths1.data_ptr<int64_t>(),
|
||||||
|
lengths2.data_ptr<int64_t>(),
|
||||||
dists.data_ptr<scalar_t>(),
|
dists.data_ptr<scalar_t>(),
|
||||||
idxs.data_ptr<int64_t>(),
|
idxs.data_ptr<int64_t>(),
|
||||||
N,
|
N,
|
||||||
|
@ -13,25 +13,38 @@
|
|||||||
// containing P1 points of dimension D.
|
// containing P1 points of dimension D.
|
||||||
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||||
// containing P2 points of dimension D.
|
// containing P2 points of dimension D.
|
||||||
|
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||||
|
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||||
// K: int giving the number of nearest points to return.
|
// K: int giving the number of nearest points to return.
|
||||||
// sorted: bool telling whether to sort the K returned points by their
|
// sorted: bool telling whether to sort the K returned points by their
|
||||||
// distance.
|
// distance.
|
||||||
// version: Integer telling which implementation to use.
|
// version: Integer telling which implementation to use.
|
||||||
// TODO(jcjohns): Document this more, or maybe remove it before landing.
|
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||||
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
||||||
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||||
|
// It is padded with zeros so that it can be used easily in a later
|
||||||
|
// gather() operation.
|
||||||
|
//
|
||||||
|
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
|
||||||
|
// distance from each point p1[n, p, :] to its K neighbors
|
||||||
|
// p2[n, p1_neighbor_idx[n, p, k], :].
|
||||||
|
|
||||||
// CPU implementation.
|
// CPU implementation.
|
||||||
std::tuple<at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||||
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
int K);
|
||||||
|
|
||||||
// CUDA implementation
|
// CUDA implementation
|
||||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||||
const at::Tensor& p1,
|
const at::Tensor& p1,
|
||||||
const at::Tensor& p2,
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
int K,
|
int K,
|
||||||
int version);
|
int version);
|
||||||
|
|
||||||
@ -39,16 +52,18 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
||||||
const at::Tensor& p1,
|
const at::Tensor& p1,
|
||||||
const at::Tensor& p2,
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
int K,
|
int K,
|
||||||
int version) {
|
int version) {
|
||||||
if (p1.is_cuda() || p2.is_cuda()) {
|
if (p1.is_cuda() || p2.is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
CHECK_CONTIGUOUS_CUDA(p1);
|
CHECK_CONTIGUOUS_CUDA(p1);
|
||||||
CHECK_CONTIGUOUS_CUDA(p2);
|
CHECK_CONTIGUOUS_CUDA(p2);
|
||||||
return KNearestNeighborIdxCuda(p1, p2, K, version);
|
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
|
||||||
#else
|
#else
|
||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
return KNearestNeighborIdxCpu(p1, p2, K);
|
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
|
||||||
}
|
}
|
||||||
|
@ -4,27 +4,35 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||||
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
int K) {
|
||||||
const int N = p1.size(0);
|
const int N = p1.size(0);
|
||||||
const int P1 = p1.size(1);
|
const int P1 = p1.size(1);
|
||||||
const int D = p1.size(2);
|
const int D = p1.size(2);
|
||||||
const int P2 = p2.size(1);
|
const int P2 = p2.size(1);
|
||||||
|
|
||||||
auto long_opts = p1.options().dtype(torch::kInt64);
|
auto long_opts = p1.options().dtype(torch::kInt64);
|
||||||
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
|
torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
|
||||||
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
|
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
|
||||||
|
|
||||||
auto p1_a = p1.accessor<float, 3>();
|
auto p1_a = p1.accessor<float, 3>();
|
||||||
auto p2_a = p2.accessor<float, 3>();
|
auto p2_a = p2.accessor<float, 3>();
|
||||||
|
auto lengths1_a = lengths1.accessor<int64_t, 1>();
|
||||||
|
auto lengths2_a = lengths2.accessor<int64_t, 1>();
|
||||||
auto idxs_a = idxs.accessor<int64_t, 3>();
|
auto idxs_a = idxs.accessor<int64_t, 3>();
|
||||||
auto dists_a = dists.accessor<float, 3>();
|
auto dists_a = dists.accessor<float, 3>();
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
for (int i1 = 0; i1 < P1; ++i1) {
|
const int64_t length1 = lengths1_a[n];
|
||||||
|
const int64_t length2 = lengths2_a[n];
|
||||||
|
for (int64_t i1 = 0; i1 < length1; ++i1) {
|
||||||
// Use a priority queue to store (distance, index) tuples.
|
// Use a priority queue to store (distance, index) tuples.
|
||||||
std::priority_queue<std::tuple<float, int>> q;
|
std::priority_queue<std::tuple<float, int>> q;
|
||||||
for (int i2 = 0; i2 < P2; ++i2) {
|
for (int64_t i2 = 0; i2 < length2; ++i2) {
|
||||||
float dist = 0;
|
float dist = 0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
|
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
|
||||||
|
@ -4,37 +4,80 @@ import torch
|
|||||||
from pytorch3d import _C
|
from pytorch3d import _C
|
||||||
|
|
||||||
|
|
||||||
def knn_points_idx(p1, p2, K, sorted=False, version=-1):
|
def knn_points_idx(
|
||||||
|
p1,
|
||||||
|
p2,
|
||||||
|
K: int,
|
||||||
|
lengths1=None,
|
||||||
|
lengths2=None,
|
||||||
|
sorted: bool = False,
|
||||||
|
version: int = -1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
K-Nearest neighbors on point clouds.
|
K-Nearest neighbors on point clouds.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
||||||
containing P1 points of dimension D.
|
containing up to P1 points of dimension D.
|
||||||
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
||||||
containing P2 points of dimension D.
|
containing up to P2 points of dimension D.
|
||||||
K: Integer giving the number of nearest neighbors to return
|
K: Integer giving the number of nearest neighbors to return.
|
||||||
|
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||||
|
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||||
|
length P1.
|
||||||
|
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||||
|
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||||
|
length P2.
|
||||||
sorted: Whether to sort the resulting points.
|
sorted: Whether to sort the resulting points.
|
||||||
version: Which KNN implementation to use in the backend. If version=-1,
|
version: Which KNN implementation to use in the backend. If version=-1,
|
||||||
the correct implementation is selected based on the shapes of
|
the correct implementation is selected based on the shapes of the inputs.
|
||||||
the inputs.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||||
K nearest neighbors from points in p1 to points in p2.
|
K nearest neighbors from points in p1 to points in p2.
|
||||||
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
|
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest
|
||||||
nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then
|
neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth
|
||||||
p2[n, j] is the kth nearest neighbor to p1[n, i].
|
nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud
|
||||||
|
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
|
||||||
|
points.
|
||||||
|
If you want an (N, P1, K, D) tensor of the actual points, you can get it
|
||||||
|
using
|
||||||
|
p2[:, :, None].expand(-1, -1, K, -1).gather(1,
|
||||||
|
x_idx[:, :, :, None].expand(-1, -1, -1, D)
|
||||||
|
)
|
||||||
|
If K=1 and you want an (N, P1, D) tensor of the actual points, use
|
||||||
|
p2.gather(1, x_idx.expand(-1, -1, D))
|
||||||
|
|
||||||
|
p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||||
|
the nearest neighbors. This is padded with zeros both where a cloud in p2
|
||||||
|
has fewer than K points and where a cloud in p1 has fewer than P1 points.
|
||||||
|
Warning: this is calculated outside of the autograd framework.
|
||||||
"""
|
"""
|
||||||
idx, dists = _C.knn_points_idx(p1, p2, K, version)
|
P1 = p1.shape[1]
|
||||||
|
P2 = p2.shape[1]
|
||||||
|
if lengths1 is None:
|
||||||
|
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
|
||||||
|
if lengths2 is None:
|
||||||
|
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
|
||||||
|
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
|
||||||
if sorted:
|
if sorted:
|
||||||
dists, sort_idx = dists.sort(dim=2)
|
if lengths2.min() < K:
|
||||||
|
device = dists.device
|
||||||
|
mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None]
|
||||||
|
# mask1 has shape [N, K], true where dists irrelevant
|
||||||
|
mask2 = mask1[:, None].expand(-1, P1, -1)
|
||||||
|
# mask2 has shape [N, P1, K], true where dists irrelevant
|
||||||
|
dists[mask2] = float("inf")
|
||||||
|
dists, sort_idx = dists.sort(dim=2)
|
||||||
|
dists[mask2] = 0
|
||||||
|
else:
|
||||||
|
dists, sort_idx = dists.sort(dim=2)
|
||||||
idx = idx.gather(2, sort_idx)
|
idx = idx.gather(2, sort_idx)
|
||||||
return idx, dists
|
return idx, dists
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
|
def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Naive PyTorch implementation of K-Nearest Neighbors.
|
Naive PyTorch implementation of K-Nearest Neighbors.
|
||||||
|
|
||||||
@ -43,25 +86,67 @@ def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
||||||
containing P1 points of dimension D.
|
containing up to P1 points of dimension D.
|
||||||
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
||||||
containing P2 points of dimension D.
|
containing up to P2 points of dimension D.
|
||||||
K: Integer giving the number of nearest neighbors to return
|
K: Integer giving the number of nearest neighbors to return.
|
||||||
sorted: Whether to sort the resulting points.
|
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||||
|
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||||
|
length P1.
|
||||||
|
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||||
|
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||||
|
length P2.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||||
K nearest neighbors from points in p1 to points in p2.
|
K nearest neighbors from points in p1 to points in p2.
|
||||||
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
|
Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor
|
||||||
nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then
|
to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer
|
||||||
p2[n, j] is the kth nearest neighbor to p1[n, i].
|
than K points and where a cloud in p1 has fewer than P1 points.
|
||||||
dists: Tensor of shape (N, P1, K) giving the distances to the nearest
|
dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest
|
||||||
neighbors.
|
neighbors. This is padded with zeros both where a cloud in p2 has fewer than
|
||||||
|
K points and where a cloud in p1 has fewer than P1 points.
|
||||||
"""
|
"""
|
||||||
N, P1, D = p1.shape
|
N, P1, D = p1.shape
|
||||||
_N, P2, _D = p2.shape
|
_N, P2, _D = p2.shape
|
||||||
|
|
||||||
assert N == _N and D == _D
|
assert N == _N and D == _D
|
||||||
diffs = p1.view(N, P1, 1, D) - p2.view(N, 1, P2, D)
|
|
||||||
|
if lengths1 is None:
|
||||||
|
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
|
||||||
|
if lengths2 is None:
|
||||||
|
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
||||||
|
|
||||||
|
p1_copy = p1.clone()
|
||||||
|
p2_copy = p2.clone()
|
||||||
|
|
||||||
|
# We pad the values with infinities so that the smallest differences are
|
||||||
|
# among actual points.
|
||||||
|
inf = float("inf")
|
||||||
|
p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None]
|
||||||
|
p1_copy[p1_mask] = inf
|
||||||
|
p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf
|
||||||
|
|
||||||
|
# view is safe here: we are merely adding extra dimensions of length 1
|
||||||
|
diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D)
|
||||||
dists2 = (diffs * diffs).sum(dim=3)
|
dists2 = (diffs * diffs).sum(dim=3)
|
||||||
out = dists2.topk(K, dim=2, largest=False, sorted=sorted)
|
|
||||||
return out.indices, out.values
|
# We always sort, because this works well with padding.
|
||||||
|
out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True)
|
||||||
|
|
||||||
|
out_indices = out.indices
|
||||||
|
out_values = out.values
|
||||||
|
|
||||||
|
if P2 < K:
|
||||||
|
# Need to add padding
|
||||||
|
pad_shape = (N, P1, K - P2)
|
||||||
|
out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2)
|
||||||
|
out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2)
|
||||||
|
|
||||||
|
K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None]
|
||||||
|
# Create a combined mask for where the points in p1 are padded
|
||||||
|
# or the corresponding p2 has fewer than K points.
|
||||||
|
p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :]
|
||||||
|
out_indices[p1_K_mask] = 0
|
||||||
|
out_values[p1_K_mask] = 0
|
||||||
|
return out_indices, out_values
|
||||||
|
@ -13,6 +13,7 @@ def bm_knn() -> None:
|
|||||||
benchmark_knn_cpu()
|
benchmark_knn_cpu()
|
||||||
benchmark_knn_cuda_vs_naive()
|
benchmark_knn_cuda_vs_naive()
|
||||||
benchmark_knn_cuda_versions()
|
benchmark_knn_cuda_versions()
|
||||||
|
benchmark_knn_cuda_versions_ragged()
|
||||||
|
|
||||||
|
|
||||||
def benchmark_knn_cuda_versions() -> None:
|
def benchmark_knn_cuda_versions() -> None:
|
||||||
@ -36,6 +37,25 @@ def benchmark_knn_cuda_versions() -> None:
|
|||||||
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
|
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_knn_cuda_versions_ragged() -> None:
|
||||||
|
# Compare our different KNN implementations,
|
||||||
|
# and also compare against our existing 1-NN
|
||||||
|
Ns = [8]
|
||||||
|
Ps = [4096, 16384]
|
||||||
|
Ds = [3]
|
||||||
|
Ks = [1, 4, 16, 64]
|
||||||
|
versions = [0, 1, 2, 3]
|
||||||
|
knn_kwargs = []
|
||||||
|
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
|
||||||
|
if version == 2 and K > 32:
|
||||||
|
continue
|
||||||
|
if version == 3 and K > 4:
|
||||||
|
continue
|
||||||
|
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
|
||||||
|
benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
|
||||||
|
benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
|
||||||
|
|
||||||
|
|
||||||
def benchmark_knn_cuda_vs_naive() -> None:
|
def benchmark_knn_cuda_vs_naive() -> None:
|
||||||
# Compare against naive pytorch version of KNN
|
# Compare against naive pytorch version of KNN
|
||||||
Ns = [1, 2, 4]
|
Ns = [1, 2, 4]
|
||||||
@ -72,10 +92,27 @@ def knn_cuda_with_init(N, D, P, K, v=-1):
|
|||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def knn():
|
def knn():
|
||||||
_C.knn_points_idx(x, y, K, v)
|
_C.knn_points_idx(x, y, lengths, lengths, K, v)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return knn
|
||||||
|
|
||||||
|
|
||||||
|
def knn_cuda_ragged(N, D, P, K, v=-1):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
x = torch.randn(N, P, D, device=device)
|
||||||
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
||||||
|
lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def knn():
|
||||||
|
_C.knn_points_idx(x, y, lengths1, lengths2, K, v)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return knn
|
return knn
|
||||||
@ -85,9 +122,10 @@ def knn_cpu_with_init(N, D, P, K):
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
def knn():
|
def knn():
|
||||||
_C.knn_points_idx(x, y, K, 0)
|
_C.knn_points_idx(x, y, lengths, lengths, K, -1)
|
||||||
|
|
||||||
return knn
|
return knn
|
||||||
|
|
||||||
@ -96,10 +134,12 @@ def knn_python_cuda_with_init(N, D, P, K):
|
|||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def knn():
|
def knn():
|
||||||
_knn_points_idx_naive(x, y, K)
|
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return knn
|
return knn
|
||||||
@ -109,9 +149,10 @@ def knn_python_cpu_with_init(N, D, P, K):
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
def knn():
|
def knn():
|
||||||
_knn_points_idx_naive(x, y, K)
|
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
||||||
|
|
||||||
return knn
|
return knn
|
||||||
|
|
||||||
|
@ -8,6 +8,10 @@ from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
|
|||||||
|
|
||||||
|
|
||||||
class TestKNN(unittest.TestCase):
|
class TestKNN(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
torch.manual_seed(1)
|
||||||
|
|
||||||
def _check_knn_result(self, out1, out2, sorted):
|
def _check_knn_result(self, out1, out2, sorted):
|
||||||
# When sorted=True, points should be sorted by distance and should
|
# When sorted=True, points should be sorted by distance and should
|
||||||
# match between implementations. When sorted=False we we only want to
|
# match between implementations. When sorted=False we we only want to
|
||||||
@ -26,7 +30,7 @@ class TestKNN(unittest.TestCase):
|
|||||||
self.assertTrue(torch.all(idx1 == idx2))
|
self.assertTrue(torch.all(idx1 == idx2))
|
||||||
self.assertTrue(torch.allclose(dist1, dist2))
|
self.assertTrue(torch.allclose(dist1, dist2))
|
||||||
|
|
||||||
def test_knn_vs_python_cpu(self):
|
def test_knn_vs_python_cpu_square(self):
|
||||||
""" Test CPU output vs PyTorch implementation """
|
""" Test CPU output vs PyTorch implementation """
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
Ns = [1, 4]
|
Ns = [1, 4]
|
||||||
@ -37,13 +41,19 @@ class TestKNN(unittest.TestCase):
|
|||||||
sorts = [True, False]
|
sorts = [True, False]
|
||||||
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
|
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
|
||||||
for N, D, P1, P2, K, sort in product(*factors):
|
for N, D, P1, P2, K, sort in product(*factors):
|
||||||
|
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
|
||||||
|
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
|
||||||
x = torch.randn(N, P1, D, device=device)
|
x = torch.randn(N, P1, D, device=device)
|
||||||
y = torch.randn(N, P2, D, device=device)
|
y = torch.randn(N, P2, D, device=device)
|
||||||
out1 = _knn_points_idx_naive(x, y, K, sort)
|
out1 = _knn_points_idx_naive(
|
||||||
out2 = knn_points_idx(x, y, K, sort)
|
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||||
|
)
|
||||||
|
out2 = knn_points_idx(
|
||||||
|
x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
|
||||||
|
)
|
||||||
self._check_knn_result(out1, out2, sort)
|
self._check_knn_result(out1, out2, sort)
|
||||||
|
|
||||||
def test_knn_vs_python_cuda(self):
|
def test_knn_vs_python_cuda_square(self):
|
||||||
""" Test CUDA output vs PyTorch implementation """
|
""" Test CUDA output vs PyTorch implementation """
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
Ns = [1, 4]
|
Ns = [1, 4]
|
||||||
@ -57,9 +67,53 @@ class TestKNN(unittest.TestCase):
|
|||||||
for N, D, P1, P2, K, sort in product(*factors):
|
for N, D, P1, P2, K, sort in product(*factors):
|
||||||
x = torch.randn(N, P1, D, device=device)
|
x = torch.randn(N, P1, D, device=device)
|
||||||
y = torch.randn(N, P2, D, device=device)
|
y = torch.randn(N, P2, D, device=device)
|
||||||
out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
|
out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
|
||||||
for version in versions:
|
for version in versions:
|
||||||
if version == 3 and K > 4:
|
if version == 3 and K > 4:
|
||||||
continue
|
continue
|
||||||
out2 = knn_points_idx(x, y, K, sort, version)
|
out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
|
||||||
|
self._check_knn_result(out1, out2, sort)
|
||||||
|
|
||||||
|
def test_knn_vs_python_cpu_ragged(self):
|
||||||
|
device = torch.device("cpu")
|
||||||
|
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
||||||
|
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
||||||
|
N = 4
|
||||||
|
D = 3
|
||||||
|
Ks = [1, 9, 10, 11, 101]
|
||||||
|
sorts = [False, True]
|
||||||
|
factors = [Ks, sorts]
|
||||||
|
for K, sort in product(*factors):
|
||||||
|
x = torch.randn(N, lengths1.max(), D, device=device)
|
||||||
|
y = torch.randn(N, lengths2.max(), D, device=device)
|
||||||
|
out1 = _knn_points_idx_naive(
|
||||||
|
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||||
|
)
|
||||||
|
out2 = knn_points_idx(
|
||||||
|
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
||||||
|
)
|
||||||
|
self._check_knn_result(out1, out2, sort)
|
||||||
|
|
||||||
|
def test_knn_vs_python_cuda_ragged(self):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
||||||
|
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
||||||
|
N = 4
|
||||||
|
D = 3
|
||||||
|
Ks = [1, 9, 10, 11, 101]
|
||||||
|
sorts = [True, False]
|
||||||
|
versions = [0, 1, 2, 3]
|
||||||
|
factors = [Ks, sorts]
|
||||||
|
for K, sort in product(*factors):
|
||||||
|
x = torch.randn(N, lengths1.max(), D, device=device)
|
||||||
|
y = torch.randn(N, lengths2.max(), D, device=device)
|
||||||
|
out1 = _knn_points_idx_naive(
|
||||||
|
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||||
|
)
|
||||||
|
for version in versions:
|
||||||
|
if version == 3 and K > 4:
|
||||||
|
continue
|
||||||
|
out2 = knn_points_idx(
|
||||||
|
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
||||||
|
)
|
||||||
self._check_knn_result(out1, out2, sort)
|
self._check_knn_result(out1, out2, sort)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user