mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
heterogenous KNN
Summary: Interface and working implementation of ragged KNN. Benchmarks (which aren't ragged) haven't slowed. New benchmark shows that ragged is faster than non-ragged of the same shape. Reviewed By: jcjohnson Differential Revision: D20696507 fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
This commit is contained in:
parent
29b9c44c0a
commit
01b5f7b228
@ -8,10 +8,20 @@
|
||||
#include "dispatch.cuh"
|
||||
#include "mink.cuh"
|
||||
|
||||
// A chunk of work is blocksize-many points of P1.
|
||||
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
|
||||
// call (1+(P1-1)/blocksize) chunks_per_cloud
|
||||
// These chunks are divided among the gridSize-many blocks.
|
||||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
|
||||
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
|
||||
// blocksize*(i%chunks_per_cloud).
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void KNearestNeighborKernelV0(
|
||||
const scalar_t* __restrict__ points1,
|
||||
const scalar_t* __restrict__ points2,
|
||||
const int64_t* __restrict__ lengths1,
|
||||
const int64_t* __restrict__ lengths2,
|
||||
scalar_t* __restrict__ dists,
|
||||
int64_t* __restrict__ idxs,
|
||||
const size_t N,
|
||||
@ -19,18 +29,19 @@ __global__ void KNearestNeighborKernelV0(
|
||||
const size_t P2,
|
||||
const size_t D,
|
||||
const size_t K) {
|
||||
// Stupid version: Make each thread handle one query point and loop over
|
||||
// all P2 target points. There are N * P1 input points to handle, so
|
||||
// do a trivial parallelization over threads.
|
||||
// Store both dists and indices for knn in global memory.
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int num_threads = blockDim.x * gridDim.x;
|
||||
for (int np = tid; np < N * P1; np += num_threads) {
|
||||
int n = np / P1;
|
||||
int p1 = np % P1;
|
||||
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||
const int64_t n = chunk / chunks_per_cloud;
|
||||
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||
int64_t p1 = start_point + threadIdx.x;
|
||||
if (p1 >= lengths1[n])
|
||||
continue;
|
||||
int offset = n * P1 * K + p1 * K;
|
||||
int64_t length2 = lengths2[n];
|
||||
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
|
||||
for (int p2 = 0; p2 < P2; ++p2) {
|
||||
for (int p2 = 0; p2 < length2; ++p2) {
|
||||
// Find the distance between points1[n, p1] and points[n, p2]
|
||||
scalar_t dist = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
|
||||
__global__ void KNearestNeighborKernelV1(
|
||||
const scalar_t* __restrict__ points1,
|
||||
const scalar_t* __restrict__ points2,
|
||||
const int64_t* __restrict__ lengths1,
|
||||
const int64_t* __restrict__ lengths2,
|
||||
scalar_t* __restrict__ dists,
|
||||
int64_t* __restrict__ idxs,
|
||||
const size_t N,
|
||||
@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
|
||||
// so we can cache the current point in a thread-local array. We still store
|
||||
// the current best K dists and indices in global memory, so this should work
|
||||
// for very large K and fairly large D.
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int num_threads = blockDim.x * gridDim.x;
|
||||
scalar_t cur_point[D];
|
||||
for (int np = tid; np < N * P1; np += num_threads) {
|
||||
int n = np / P1;
|
||||
int p1 = np % P1;
|
||||
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||
const int64_t n = chunk / chunks_per_cloud;
|
||||
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||
int64_t p1 = start_point + threadIdx.x;
|
||||
if (p1 >= lengths1[n])
|
||||
continue;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
||||
}
|
||||
int offset = n * P1 * K + p1 * K;
|
||||
int64_t length2 = lengths2[n];
|
||||
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
|
||||
for (int p2 = 0; p2 < P2; ++p2) {
|
||||
for (int p2 = 0; p2 < length2; ++p2) {
|
||||
// Find the distance between cur_point and points[n, p2]
|
||||
scalar_t dist = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
@ -89,14 +106,16 @@ struct KNearestNeighborV1Functor {
|
||||
size_t threads,
|
||||
const scalar_t* __restrict__ points1,
|
||||
const scalar_t* __restrict__ points2,
|
||||
const int64_t* __restrict__ lengths1,
|
||||
const int64_t* __restrict__ lengths2,
|
||||
scalar_t* __restrict__ dists,
|
||||
int64_t* __restrict__ idxs,
|
||||
const size_t N,
|
||||
const size_t P1,
|
||||
const size_t P2,
|
||||
const size_t K) {
|
||||
KNearestNeighborKernelV1<scalar_t, D>
|
||||
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
|
||||
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
|
||||
}
|
||||
};
|
||||
|
||||
@ -104,25 +123,31 @@ template <typename scalar_t, int64_t D, int64_t K>
|
||||
__global__ void KNearestNeighborKernelV2(
|
||||
const scalar_t* __restrict__ points1,
|
||||
const scalar_t* __restrict__ points2,
|
||||
const int64_t* __restrict__ lengths1,
|
||||
const int64_t* __restrict__ lengths2,
|
||||
scalar_t* __restrict__ dists,
|
||||
int64_t* __restrict__ idxs,
|
||||
const int64_t N,
|
||||
const int64_t P1,
|
||||
const int64_t P2) {
|
||||
// Same general implementation as V2, but also hoist K into a template arg.
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int num_threads = blockDim.x * gridDim.x;
|
||||
scalar_t cur_point[D];
|
||||
scalar_t min_dists[K];
|
||||
int min_idxs[K];
|
||||
for (int np = tid; np < N * P1; np += num_threads) {
|
||||
int n = np / P1;
|
||||
int p1 = np % P1;
|
||||
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||
const int64_t n = chunk / chunks_per_cloud;
|
||||
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||
int64_t p1 = start_point + threadIdx.x;
|
||||
if (p1 >= lengths1[n])
|
||||
continue;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
||||
}
|
||||
int64_t length2 = lengths2[n];
|
||||
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
|
||||
for (int p2 = 0; p2 < P2; ++p2) {
|
||||
for (int p2 = 0; p2 < length2; ++p2) {
|
||||
scalar_t dist = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
int offset = n * P2 * D + p2 * D + d;
|
||||
@ -146,13 +171,15 @@ struct KNearestNeighborKernelV2Functor {
|
||||
size_t threads,
|
||||
const scalar_t* __restrict__ points1,
|
||||
const scalar_t* __restrict__ points2,
|
||||
const int64_t* __restrict__ lengths1,
|
||||
const int64_t* __restrict__ lengths2,
|
||||
scalar_t* __restrict__ dists,
|
||||
int64_t* __restrict__ idxs,
|
||||
const int64_t N,
|
||||
const int64_t P1,
|
||||
const int64_t P2) {
|
||||
KNearestNeighborKernelV2<scalar_t, D, K>
|
||||
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
|
||||
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
|
||||
}
|
||||
};
|
||||
|
||||
@ -160,6 +187,8 @@ template <typename scalar_t, int D, int K>
|
||||
__global__ void KNearestNeighborKernelV3(
|
||||
const scalar_t* __restrict__ points1,
|
||||
const scalar_t* __restrict__ points2,
|
||||
const int64_t* __restrict__ lengths1,
|
||||
const int64_t* __restrict__ lengths2,
|
||||
scalar_t* __restrict__ dists,
|
||||
int64_t* __restrict__ idxs,
|
||||
const size_t N,
|
||||
@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
|
||||
// Enabling sorting for this version leads to huge slowdowns; I suspect
|
||||
// that it forces min_dists into local memory rather than registers.
|
||||
// As a result this version is always unsorted.
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int num_threads = blockDim.x * gridDim.x;
|
||||
scalar_t cur_point[D];
|
||||
scalar_t min_dists[K];
|
||||
int min_idxs[K];
|
||||
for (int np = tid; np < N * P1; np += num_threads) {
|
||||
int n = np / P1;
|
||||
int p1 = np % P1;
|
||||
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||
const int64_t n = chunk / chunks_per_cloud;
|
||||
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||
int64_t p1 = start_point + threadIdx.x;
|
||||
if (p1 >= lengths1[n])
|
||||
continue;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
cur_point[d] = points1[n * P1 * D + p1 * D + d];
|
||||
}
|
||||
int64_t length2 = lengths2[n];
|
||||
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
|
||||
for (int p2 = 0; p2 < P2; ++p2) {
|
||||
for (int p2 = 0; p2 < length2; ++p2) {
|
||||
scalar_t dist = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
int offset = n * P2 * D + p2 * D + d;
|
||||
@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
|
||||
size_t threads,
|
||||
const scalar_t* __restrict__ points1,
|
||||
const scalar_t* __restrict__ points2,
|
||||
const int64_t* __restrict__ lengths1,
|
||||
const int64_t* __restrict__ lengths2,
|
||||
scalar_t* __restrict__ dists,
|
||||
int64_t* __restrict__ idxs,
|
||||
const size_t N,
|
||||
const size_t P1,
|
||||
const size_t P2) {
|
||||
KNearestNeighborKernelV3<scalar_t, D, K>
|
||||
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
|
||||
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
|
||||
}
|
||||
};
|
||||
|
||||
@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
int version) {
|
||||
const auto N = p1.size(0);
|
||||
@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
|
||||
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
|
||||
auto long_dtype = p1.options().dtype(at::kLong);
|
||||
auto idxs = at::full({N, P1, K}, -1, long_dtype);
|
||||
auto dists = at::full({N, P1, K}, -1, p1.options());
|
||||
auto idxs = at::zeros({N, P1, K}, long_dtype);
|
||||
auto dists = at::zeros({N, P1, K}, p1.options());
|
||||
|
||||
if (version < 0) {
|
||||
version = ChooseVersion(D, K);
|
||||
@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
<<<blocks, threads>>>(
|
||||
p1.data_ptr<scalar_t>(),
|
||||
p2.data_ptr<scalar_t>(),
|
||||
lengths1.data_ptr<int64_t>(),
|
||||
lengths2.data_ptr<int64_t>(),
|
||||
dists.data_ptr<scalar_t>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
N,
|
||||
@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
threads,
|
||||
p1.data_ptr<scalar_t>(),
|
||||
p2.data_ptr<scalar_t>(),
|
||||
lengths1.data_ptr<int64_t>(),
|
||||
lengths2.data_ptr<int64_t>(),
|
||||
dists.data_ptr<scalar_t>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
N,
|
||||
@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
threads,
|
||||
p1.data_ptr<scalar_t>(),
|
||||
p2.data_ptr<scalar_t>(),
|
||||
lengths1.data_ptr<int64_t>(),
|
||||
lengths2.data_ptr<int64_t>(),
|
||||
dists.data_ptr<scalar_t>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
N,
|
||||
@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
threads,
|
||||
p1.data_ptr<scalar_t>(),
|
||||
p2.data_ptr<scalar_t>(),
|
||||
lengths1.data_ptr<int64_t>(),
|
||||
lengths2.data_ptr<int64_t>(),
|
||||
dists.data_ptr<scalar_t>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
N,
|
||||
|
@ -13,25 +13,38 @@
|
||||
// containing P1 points of dimension D.
|
||||
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||
// containing P2 points of dimension D.
|
||||
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||
// K: int giving the number of nearest points to return.
|
||||
// sorted: bool telling whether to sort the K returned points by their
|
||||
// distance.
|
||||
// version: Integer telling which implementation to use.
|
||||
// TODO(jcjohns): Document this more, or maybe remove it before landing.
|
||||
//
|
||||
// Returns:
|
||||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
||||
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
||||
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||
// It is padded with zeros so that it can be used easily in a later
|
||||
// gather() operation.
|
||||
//
|
||||
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
|
||||
// distance from each point p1[n, p, :] to its K neighbors
|
||||
// p2[n, p1_neighbor_idx[n, p, k], :].
|
||||
|
||||
// CPU implementation.
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K);
|
||||
|
||||
// CUDA implementation
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
int version);
|
||||
|
||||
@ -39,16 +52,18 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
int version) {
|
||||
if (p1.is_cuda() || p2.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CONTIGUOUS_CUDA(p1);
|
||||
CHECK_CONTIGUOUS_CUDA(p2);
|
||||
return KNearestNeighborIdxCuda(p1, p2, K, version);
|
||||
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return KNearestNeighborIdxCpu(p1, p2, K);
|
||||
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
|
||||
}
|
||||
|
@ -4,27 +4,35 @@
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K) {
|
||||
const int N = p1.size(0);
|
||||
const int P1 = p1.size(1);
|
||||
const int D = p1.size(2);
|
||||
const int P2 = p2.size(1);
|
||||
|
||||
auto long_opts = p1.options().dtype(torch::kInt64);
|
||||
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
|
||||
torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
|
||||
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
|
||||
|
||||
auto p1_a = p1.accessor<float, 3>();
|
||||
auto p2_a = p2.accessor<float, 3>();
|
||||
auto lengths1_a = lengths1.accessor<int64_t, 1>();
|
||||
auto lengths2_a = lengths2.accessor<int64_t, 1>();
|
||||
auto idxs_a = idxs.accessor<int64_t, 3>();
|
||||
auto dists_a = dists.accessor<float, 3>();
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int i1 = 0; i1 < P1; ++i1) {
|
||||
const int64_t length1 = lengths1_a[n];
|
||||
const int64_t length2 = lengths2_a[n];
|
||||
for (int64_t i1 = 0; i1 < length1; ++i1) {
|
||||
// Use a priority queue to store (distance, index) tuples.
|
||||
std::priority_queue<std::tuple<float, int>> q;
|
||||
for (int i2 = 0; i2 < P2; ++i2) {
|
||||
for (int64_t i2 = 0; i2 < length2; ++i2) {
|
||||
float dist = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
|
||||
|
@ -4,37 +4,80 @@ import torch
|
||||
from pytorch3d import _C
|
||||
|
||||
|
||||
def knn_points_idx(p1, p2, K, sorted=False, version=-1):
|
||||
def knn_points_idx(
|
||||
p1,
|
||||
p2,
|
||||
K: int,
|
||||
lengths1=None,
|
||||
lengths2=None,
|
||||
sorted: bool = False,
|
||||
version: int = -1,
|
||||
):
|
||||
"""
|
||||
K-Nearest neighbors on point clouds.
|
||||
|
||||
Args:
|
||||
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
||||
containing P1 points of dimension D.
|
||||
containing up to P1 points of dimension D.
|
||||
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
||||
containing P2 points of dimension D.
|
||||
K: Integer giving the number of nearest neighbors to return
|
||||
containing up to P2 points of dimension D.
|
||||
K: Integer giving the number of nearest neighbors to return.
|
||||
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||
length P1.
|
||||
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||
length P2.
|
||||
sorted: Whether to sort the resulting points.
|
||||
version: Which KNN implementation to use in the backend. If version=-1,
|
||||
the correct implementation is selected based on the shapes of
|
||||
the inputs.
|
||||
the correct implementation is selected based on the shapes of the inputs.
|
||||
|
||||
Returns:
|
||||
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||
K nearest neighbors from points in p1 to points in p2.
|
||||
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
|
||||
nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then
|
||||
p2[n, j] is the kth nearest neighbor to p1[n, i].
|
||||
p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||
K nearest neighbors from points in p1 to points in p2.
|
||||
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest
|
||||
neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth
|
||||
nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud
|
||||
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
|
||||
points.
|
||||
If you want an (N, P1, K, D) tensor of the actual points, you can get it
|
||||
using
|
||||
p2[:, :, None].expand(-1, -1, K, -1).gather(1,
|
||||
x_idx[:, :, :, None].expand(-1, -1, -1, D)
|
||||
)
|
||||
If K=1 and you want an (N, P1, D) tensor of the actual points, use
|
||||
p2.gather(1, x_idx.expand(-1, -1, D))
|
||||
|
||||
p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||
the nearest neighbors. This is padded with zeros both where a cloud in p2
|
||||
has fewer than K points and where a cloud in p1 has fewer than P1 points.
|
||||
Warning: this is calculated outside of the autograd framework.
|
||||
"""
|
||||
idx, dists = _C.knn_points_idx(p1, p2, K, version)
|
||||
P1 = p1.shape[1]
|
||||
P2 = p2.shape[1]
|
||||
if lengths1 is None:
|
||||
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
|
||||
if lengths2 is None:
|
||||
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
|
||||
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
|
||||
if sorted:
|
||||
dists, sort_idx = dists.sort(dim=2)
|
||||
if lengths2.min() < K:
|
||||
device = dists.device
|
||||
mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None]
|
||||
# mask1 has shape [N, K], true where dists irrelevant
|
||||
mask2 = mask1[:, None].expand(-1, P1, -1)
|
||||
# mask2 has shape [N, P1, K], true where dists irrelevant
|
||||
dists[mask2] = float("inf")
|
||||
dists, sort_idx = dists.sort(dim=2)
|
||||
dists[mask2] = 0
|
||||
else:
|
||||
dists, sort_idx = dists.sort(dim=2)
|
||||
idx = idx.gather(2, sort_idx)
|
||||
return idx, dists
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
|
||||
def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2) -> torch.Tensor:
|
||||
"""
|
||||
Naive PyTorch implementation of K-Nearest Neighbors.
|
||||
|
||||
@ -43,25 +86,67 @@ def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
|
||||
|
||||
Args:
|
||||
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
||||
containing P1 points of dimension D.
|
||||
containing up to P1 points of dimension D.
|
||||
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
||||
containing P2 points of dimension D.
|
||||
K: Integer giving the number of nearest neighbors to return
|
||||
sorted: Whether to sort the resulting points.
|
||||
containing up to P2 points of dimension D.
|
||||
K: Integer giving the number of nearest neighbors to return.
|
||||
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||
length P1.
|
||||
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||
length P2.
|
||||
|
||||
Returns:
|
||||
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||
K nearest neighbors from points in p1 to points in p2.
|
||||
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
|
||||
nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then
|
||||
p2[n, j] is the kth nearest neighbor to p1[n, i].
|
||||
dists: Tensor of shape (N, P1, K) giving the distances to the nearest
|
||||
neighbors.
|
||||
K nearest neighbors from points in p1 to points in p2.
|
||||
Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor
|
||||
to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer
|
||||
than K points and where a cloud in p1 has fewer than P1 points.
|
||||
dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest
|
||||
neighbors. This is padded with zeros both where a cloud in p2 has fewer than
|
||||
K points and where a cloud in p1 has fewer than P1 points.
|
||||
"""
|
||||
N, P1, D = p1.shape
|
||||
_N, P2, _D = p2.shape
|
||||
|
||||
assert N == _N and D == _D
|
||||
diffs = p1.view(N, P1, 1, D) - p2.view(N, 1, P2, D)
|
||||
|
||||
if lengths1 is None:
|
||||
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
|
||||
if lengths2 is None:
|
||||
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
||||
|
||||
p1_copy = p1.clone()
|
||||
p2_copy = p2.clone()
|
||||
|
||||
# We pad the values with infinities so that the smallest differences are
|
||||
# among actual points.
|
||||
inf = float("inf")
|
||||
p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None]
|
||||
p1_copy[p1_mask] = inf
|
||||
p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf
|
||||
|
||||
# view is safe here: we are merely adding extra dimensions of length 1
|
||||
diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D)
|
||||
dists2 = (diffs * diffs).sum(dim=3)
|
||||
out = dists2.topk(K, dim=2, largest=False, sorted=sorted)
|
||||
return out.indices, out.values
|
||||
|
||||
# We always sort, because this works well with padding.
|
||||
out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True)
|
||||
|
||||
out_indices = out.indices
|
||||
out_values = out.values
|
||||
|
||||
if P2 < K:
|
||||
# Need to add padding
|
||||
pad_shape = (N, P1, K - P2)
|
||||
out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2)
|
||||
out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2)
|
||||
|
||||
K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None]
|
||||
# Create a combined mask for where the points in p1 are padded
|
||||
# or the corresponding p2 has fewer than K points.
|
||||
p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :]
|
||||
out_indices[p1_K_mask] = 0
|
||||
out_values[p1_K_mask] = 0
|
||||
return out_indices, out_values
|
||||
|
@ -13,6 +13,7 @@ def bm_knn() -> None:
|
||||
benchmark_knn_cpu()
|
||||
benchmark_knn_cuda_vs_naive()
|
||||
benchmark_knn_cuda_versions()
|
||||
benchmark_knn_cuda_versions_ragged()
|
||||
|
||||
|
||||
def benchmark_knn_cuda_versions() -> None:
|
||||
@ -36,6 +37,25 @@ def benchmark_knn_cuda_versions() -> None:
|
||||
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
|
||||
|
||||
|
||||
def benchmark_knn_cuda_versions_ragged() -> None:
|
||||
# Compare our different KNN implementations,
|
||||
# and also compare against our existing 1-NN
|
||||
Ns = [8]
|
||||
Ps = [4096, 16384]
|
||||
Ds = [3]
|
||||
Ks = [1, 4, 16, 64]
|
||||
versions = [0, 1, 2, 3]
|
||||
knn_kwargs = []
|
||||
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
|
||||
if version == 2 and K > 32:
|
||||
continue
|
||||
if version == 3 and K > 4:
|
||||
continue
|
||||
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
|
||||
benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
|
||||
benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
|
||||
|
||||
|
||||
def benchmark_knn_cuda_vs_naive() -> None:
|
||||
# Compare against naive pytorch version of KNN
|
||||
Ns = [1, 2, 4]
|
||||
@ -72,10 +92,27 @@ def knn_cuda_with_init(N, D, P, K, v=-1):
|
||||
device = torch.device("cuda:0")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def knn():
|
||||
_C.knn_points_idx(x, y, K, v)
|
||||
_C.knn_points_idx(x, y, lengths, lengths, K, v)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return knn
|
||||
|
||||
|
||||
def knn_cuda_ragged(N, D, P, K, v=-1):
|
||||
device = torch.device("cuda:0")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
||||
lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def knn():
|
||||
_C.knn_points_idx(x, y, lengths1, lengths2, K, v)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return knn
|
||||
@ -85,9 +122,10 @@ def knn_cpu_with_init(N, D, P, K):
|
||||
device = torch.device("cpu")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
def knn():
|
||||
_C.knn_points_idx(x, y, K, 0)
|
||||
_C.knn_points_idx(x, y, lengths, lengths, K, -1)
|
||||
|
||||
return knn
|
||||
|
||||
@ -96,10 +134,12 @@ def knn_python_cuda_with_init(N, D, P, K):
|
||||
device = torch.device("cuda")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def knn():
|
||||
_knn_points_idx_naive(x, y, K)
|
||||
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return knn
|
||||
@ -109,9 +149,10 @@ def knn_python_cpu_with_init(N, D, P, K):
|
||||
device = torch.device("cpu")
|
||||
x = torch.randn(N, P, D, device=device)
|
||||
y = torch.randn(N, P, D, device=device)
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
def knn():
|
||||
_knn_points_idx_naive(x, y, K)
|
||||
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
||||
|
||||
return knn
|
||||
|
||||
|
@ -8,6 +8,10 @@ from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
|
||||
|
||||
|
||||
class TestKNN(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
def _check_knn_result(self, out1, out2, sorted):
|
||||
# When sorted=True, points should be sorted by distance and should
|
||||
# match between implementations. When sorted=False we we only want to
|
||||
@ -26,7 +30,7 @@ class TestKNN(unittest.TestCase):
|
||||
self.assertTrue(torch.all(idx1 == idx2))
|
||||
self.assertTrue(torch.allclose(dist1, dist2))
|
||||
|
||||
def test_knn_vs_python_cpu(self):
|
||||
def test_knn_vs_python_cpu_square(self):
|
||||
""" Test CPU output vs PyTorch implementation """
|
||||
device = torch.device("cpu")
|
||||
Ns = [1, 4]
|
||||
@ -37,13 +41,19 @@ class TestKNN(unittest.TestCase):
|
||||
sorts = [True, False]
|
||||
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
|
||||
for N, D, P1, P2, K, sort in product(*factors):
|
||||
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
|
||||
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
|
||||
x = torch.randn(N, P1, D, device=device)
|
||||
y = torch.randn(N, P2, D, device=device)
|
||||
out1 = _knn_points_idx_naive(x, y, K, sort)
|
||||
out2 = knn_points_idx(x, y, K, sort)
|
||||
out1 = _knn_points_idx_naive(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||
)
|
||||
out2 = knn_points_idx(
|
||||
x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
|
||||
)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
||||
def test_knn_vs_python_cuda(self):
|
||||
def test_knn_vs_python_cuda_square(self):
|
||||
""" Test CUDA output vs PyTorch implementation """
|
||||
device = torch.device("cuda")
|
||||
Ns = [1, 4]
|
||||
@ -57,9 +67,53 @@ class TestKNN(unittest.TestCase):
|
||||
for N, D, P1, P2, K, sort in product(*factors):
|
||||
x = torch.randn(N, P1, D, device=device)
|
||||
y = torch.randn(N, P2, D, device=device)
|
||||
out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
|
||||
out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
|
||||
for version in versions:
|
||||
if version == 3 and K > 4:
|
||||
continue
|
||||
out2 = knn_points_idx(x, y, K, sort, version)
|
||||
out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
||||
def test_knn_vs_python_cpu_ragged(self):
|
||||
device = torch.device("cpu")
|
||||
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
||||
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
||||
N = 4
|
||||
D = 3
|
||||
Ks = [1, 9, 10, 11, 101]
|
||||
sorts = [False, True]
|
||||
factors = [Ks, sorts]
|
||||
for K, sort in product(*factors):
|
||||
x = torch.randn(N, lengths1.max(), D, device=device)
|
||||
y = torch.randn(N, lengths2.max(), D, device=device)
|
||||
out1 = _knn_points_idx_naive(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||
)
|
||||
out2 = knn_points_idx(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
||||
)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
||||
def test_knn_vs_python_cuda_ragged(self):
|
||||
device = torch.device("cuda")
|
||||
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
||||
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
||||
N = 4
|
||||
D = 3
|
||||
Ks = [1, 9, 10, 11, 101]
|
||||
sorts = [True, False]
|
||||
versions = [0, 1, 2, 3]
|
||||
factors = [Ks, sorts]
|
||||
for K, sort in product(*factors):
|
||||
x = torch.randn(N, lengths1.max(), D, device=device)
|
||||
y = torch.randn(N, lengths2.max(), D, device=device)
|
||||
out1 = _knn_points_idx_naive(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||
)
|
||||
for version in versions:
|
||||
if version == 3 and K > 4:
|
||||
continue
|
||||
out2 = knn_points_idx(
|
||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
||||
)
|
||||
self._check_knn_result(out1, out2, sort)
|
||||
|
Loading…
x
Reference in New Issue
Block a user