heterogeneous KNN

Summary: Interface and working implementation of ragged KNN. Existing benchmarks (which aren't ragged) haven't slowed. A new benchmark shows that ragged input is faster than non-ragged input of the same shape.

Reviewed By: jcjohnson

Differential Revision: D20696507

fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
commit 01b5f7b228 (parent 29b9c44c0a)
Jeremy Reizenstein, 2020-04-07 01:45:43 -07:00, committed by Facebook GitHub Bot
6 changed files with 332 additions and 84 deletions
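
For orientation, the new interface takes optional per-cloud lengths. A minimal usage sketch against the Python signature added to pytorch3d/ops/knn.py below (the tensor sizes and lengths here are illustrative):

import torch
from pytorch3d.ops.knn import knn_points_idx

N, P1, P2, D, K = 2, 8, 6, 3, 4
p1 = torch.randn(N, P1, D)
p2 = torch.randn(N, P2, D)
# Cloud n only uses its first lengths[n] rows; the rest is padding.
lengths1 = torch.tensor([8, 5])
lengths2 = torch.tensor([6, 2])
idx, dists = knn_points_idx(p1, p2, K, lengths1=lengths1, lengths2=lengths2, sorted=True)
# idx: (N, P1, K) indices into p2; dists: (N, P1, K) squared distances.
# Both are zero-padded where a p1 cloud has fewer than P1 points or a
# p2 cloud has fewer than K points.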

pytorch3d/csrc/knn/knn.cu

@@ -8,10 +8,20 @@
#include "dispatch.cuh"
#include "mink.cuh"
// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize, etc.
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).
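
In Python terms, the chunk bookkeeping described in the comment above works out as follows (an editorial sketch, with blocksize standing in for blockDim.x):

def chunk_to_work(chunk, P1, blocksize):
    # chunks_per_cloud = 1 + (P1 - 1) // blocksize, as in the comment above.
    chunks_per_cloud = 1 + (P1 - 1) // blocksize
    n = chunk // chunks_per_cloud  # which cloud this chunk serves
    start_point = blocksize * (chunk % chunks_per_cloud)  # first p1 index in the chunk
    return n, start_point

# Block b takes chunks b, b + gridSize, b + 2*gridSize, ... below
# N * chunks_per_cloud; thread t within the block then handles point
# start_point + t, skipping it when start_point + t >= lengths1[n].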
template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
@@ -19,18 +29,19 @@ __global__ void KNearestNeighborKernelV0(
const size_t P2,
const size_t D,
const size_t K) {
// Stupid version: Make each thread handle one query point and loop over
// all P2 target points. There are N * P1 input points to handle, so
// do a trivial parallelization over threads.
// Store both dists and indices for knn in global memory.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between points1[n, p1] and points2[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
@@ -48,6 +59,8 @@ template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
@@ -58,18 +71,22 @@ __global__ void KNearestNeighborKernelV1(
// so we can cache the current point in a thread-local array. We still store
// the current best K dists and indices in global memory, so this should work
// for very large K and fairly large D.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between cur_point and points2[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
@@ -89,14 +106,16 @@ struct KNearestNeighborV1Functor {
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
KNearestNeighborKernelV1<scalar_t, D>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2, K);
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
}
};
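
MinK (defined in mink.cuh, which is not part of this diff) tracks the K smallest distances seen so far together with their indices. Its observable behaviour is roughly the following Python sketch (an approximation for exposition, not the CUDA data structure):

class MinKSketch:
    # Keep the K smallest (dist, idx) pairs passed to add(), unsorted.
    def __init__(self, K):
        self.K = K
        self.pairs = []

    def add(self, dist, idx):
        if len(self.pairs) < self.K:
            self.pairs.append((dist, idx))
        else:
            worst = max(range(self.K), key=lambda i: self.pairs[i][0])
            if dist < self.pairs[worst][0]:
                self.pairs[worst] = (dist, idx)

In the kernels the pairs live directly in the output arrays (dists + offset, idxs + offset), so no separate write-back pass is needed.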
@@ -104,25 +123,31 @@ template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
// Same general implementation as V1, but also hoist K into a template arg.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
@@ -146,13 +171,15 @@ struct KNearestNeighborKernelV2Functor {
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
KNearestNeighborKernelV2<scalar_t, D, K>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};
@@ -160,6 +187,8 @@ template <typename scalar_t, int D, int K>
__global__ void KNearestNeighborKernelV3(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
@@ -169,19 +198,23 @@ __global__ void KNearestNeighborKernelV3(
// Enabling sorting for this version leads to huge slowdowns; I suspect
// that it forces min_dists into local memory rather than registers.
// As a result this version is always unsorted.
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
for (int np = tid; np < N * P1; np += num_threads) {
int n = np / P1;
int p1 = np % P1;
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
for (int p2 = 0; p2 < P2; ++p2) {
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
@@ -205,13 +238,15 @@ struct KNearestNeighborKernelV3Functor {
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
KNearestNeighborKernelV3<scalar_t, D, K>
<<<blocks, threads>>>(points1, points2, dists, idxs, N, P1, P2);
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};
@@ -257,6 +292,8 @@ int ChooseVersion(const int64_t D, const int64_t K) {
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version) {
const auto N = p1.size(0);
@@ -267,8 +304,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = p1.options().dtype(at::kLong);
auto idxs = at::full({N, P1, K}, -1, long_dtype);
auto dists = at::full({N, P1, K}, -1, p1.options());
auto idxs = at::zeros({N, P1, K}, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options());
if (version < 0) {
version = ChooseVersion(D, K);
@@ -294,6 +331,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
<<<blocks, threads>>>(
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
@@ -314,6 +353,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
@@ -336,6 +377,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
@@ -357,6 +400,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
threads,
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,

pytorch3d/csrc/knn/knn.h

@@ -13,25 +13,38 @@
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// K: int giving the number of nearest points to return.
// sorted: bool telling whether to sort the K returned points by their
// distance.
// version: Integer telling which implementation to use.
// TODO(jcjohns): Document this more, or maybe remove it before landing.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation.
//
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
// distance from each point p1[n, p, :] to its K neighbors
// p2[n, p1_neighbor_idx[n, p, k], :].
// CPU implementation.
std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K);
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version);
@@ -39,16 +52,18 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(p1);
CHECK_CONTIGUOUS_CUDA(p2);
return KNearestNeighborIdxCuda(p1, p2, K, version);
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborIdxCpu(p1, p2, K);
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
}

pytorch3d/csrc/knn/knn_cpu.cpp

@@ -4,27 +4,35 @@
#include <queue>
#include <tuple>
std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);
auto long_opts = p1.options().dtype(torch::kInt64);
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto lengths1_a = lengths1.accessor<int64_t, 1>();
auto lengths2_a = lengths2.accessor<int64_t, 1>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto dists_a = dists.accessor<float, 3>();
for (int n = 0; n < N; ++n) {
for (int i1 = 0; i1 < P1; ++i1) {
const int64_t length1 = lengths1_a[n];
const int64_t length2 = lengths2_a[n];
for (int64_t i1 = 0; i1 < length1; ++i1) {
// Use a priority queue to store (distance, index) tuples.
std::priority_queue<std::tuple<float, int>> q;
for (int i2 = 0; i2 < P2; ++i2) {
for (int64_t i2 = 0; i2 < length2; ++i2) {
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
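
The CPU path scans every valid p2 point for each valid p1 point, keeping the K best in a heap keyed on distance. The same idea in Python (illustrative; heapq is a min-heap, so squared distances are negated, whereas the std::priority_queue above is a max-heap and needs no negation):

import heapq

def knn_one_point(query, cloud, K):
    # Return up to K nearest points of `cloud` to `query` as (dist2, idx) pairs.
    heap = []  # holds (-dist2, idx); heap[0] is the current worst neighbor
    for j, p in enumerate(cloud):
        d2 = sum((a - b) ** 2 for a, b in zip(query, p))
        if len(heap) < K:
            heapq.heappush(heap, (-d2, j))
        elif -d2 > heap[0][0]:  # d2 beats the current worst neighbor
            heapq.heapreplace(heap, (-d2, j))
    return sorted((-neg_d2, j) for neg_d2, j in heap)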

pytorch3d/ops/knn.py

@@ -4,37 +4,80 @@ import torch
from pytorch3d import _C
def knn_points_idx(p1, p2, K, sorted=False, version=-1):
def knn_points_idx(
p1,
p2,
K: int,
lengths1=None,
lengths2=None,
sorted: bool = False,
version: int = -1,
):
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
containing P1 points of dimension D.
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
containing P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return
containing up to P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
sorted: Whether to sort the resulting points.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of
the inputs.
the correct implementation is selected based on the shapes of the inputs.
Returns:
idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then
p2[n, j] is the kth nearest neighbor to p1[n, i].
p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest
neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth
nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
points.
If you want an (N, P1, K, D) tensor of the actual points, you can get it
using
p2[:, :, None].expand(-1, -1, K, -1).gather(1,
x_idx[:, :, :, None].expand(-1, -1, -1, D)
)
If K=1 and you want an (N, P1, D) tensor of the actual points, use
p2.gather(1, x_idx.expand(-1, -1, D))
p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points.
Warning: this is calculated outside of the autograd framework.
"""
idx, dists = _C.knn_points_idx(p1, p2, K, version)
P1 = p1.shape[1]
P2 = p2.shape[1]
if lengths1 is None:
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
if lengths2 is None:
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
if sorted:
dists, sort_idx = dists.sort(dim=2)
if lengths2.min() < K:
device = dists.device
mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None]
# mask1 has shape [N, K], true where dists irrelevant
mask2 = mask1[:, None].expand(-1, P1, -1)
# mask2 has shape [N, P1, K], true where dists irrelevant
dists[mask2] = float("inf")
dists, sort_idx = dists.sort(dim=2)
dists[mask2] = 0
else:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
return idx, dists
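
One subtlety in the sorted=True branch above: where a p2 cloud has fewer than K real neighbors, the padded dists are zeros and would sort in front of the real distances. The fix is to push them to +inf, sort, then restore the zeros; since the infinities land in exactly the trailing slots that the mask marks, the same mask zeroes them again. A tiny worked example (hypothetical numbers):

import torch

K = 3
dists = torch.tensor([[[0.9, 0.4, 0.0]]])  # (N=1, P1=1, K=3); slot 2 is padding
lengths2 = torch.tensor([2])               # only 2 valid neighbors in this cloud
mask = lengths2[:, None] <= torch.arange(K)[None]  # [[False, False, True]]
mask = mask[:, None].expand(-1, 1, -1)             # broadcast over P1
dists[mask] = float("inf")           # padding out of the way: [[[0.9, 0.4, inf]]]
dists, sort_idx = dists.sort(dim=2)  # [[[0.4, 0.9, inf]]]
dists[mask] = 0                      # restore padding:        [[[0.4, 0.9, 0.0]]]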
@torch.no_grad()
def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2):
"""
Naive PyTorch implementation of K-Nearest Neighbors.
@@ -43,25 +86,67 @@ def _knn_points_idx_naive(p1, p2, K, sorted=False) -> torch.Tensor:
Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
containing P1 points of dimension D.
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
containing P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return
sorted: Whether to sort the resulting points.
containing up to P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
Returns:
idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K
nearest neighbor to p1[n, i] in p2[n]. If sorted=True, then
p2[n, j] is the kth nearest neighbor to p1[n, i].
dists: Tensor of shape (N, P1, K) giving the distances to the nearest
neighbors.
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor
to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer
than K points and where a cloud in p1 has fewer than P1 points.
dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest
neighbors. This is padded with zeros both where a cloud in p2 has fewer than
K points and where a cloud in p1 has fewer than P1 points.
"""
N, P1, D = p1.shape
_N, P2, _D = p2.shape
assert N == _N and D == _D
diffs = p1.view(N, P1, 1, D) - p2.view(N, 1, P2, D)
if lengths1 is None:
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
if lengths2 is None:
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
p1_copy = p1.clone()
p2_copy = p2.clone()
# We pad the values with infinities so that the smallest differences are
# among actual points. The signs differ (+inf in p1, -inf in p2) so that
# padded-padded pairs give inf, not NaN; see the note after this function.
inf = float("inf")
p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None]
p1_copy[p1_mask] = inf
p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf
# view is safe here: we are merely adding extra dimensions of length 1
diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D)
dists2 = (diffs * diffs).sum(dim=3)
out = dists2.topk(K, dim=2, largest=False, sorted=sorted)
return out.indices, out.values
# We always sort, because this works well with padding.
out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True)
out_indices = out.indices
out_values = out.values
if P2 < K:
# Need to add padding
pad_shape = (N, P1, K - P2)
out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2)
out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2)
K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None]
# Create a combined mask for where the points in p1 are padded
# or the corresponding p2 has fewer than K points.
p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :]
out_indices[p1_K_mask] = 0
out_values[p1_K_mask] = 0
return out_indices, out_values
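
The sign choice in the padding above is deliberate: with +inf in p1 and -inf in p2, any pair involving a padded row differences to +inf, whose squared distance sorts last. Had both sides used +inf, a padded-padded pair would give inf - inf = NaN, whose ordering under sort/topk is not something to rely on. A quick check (illustrative):

import torch

inf = float("inf")
print(torch.tensor(inf) - torch.tensor(-inf))  # tensor(inf) -> sorts last, as intended
print(torch.tensor(inf) - torch.tensor(inf))   # tensor(nan) -> unreliable in topk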

tests/bm_knn.py

@@ -13,6 +13,7 @@ def bm_knn() -> None:
benchmark_knn_cpu()
benchmark_knn_cuda_vs_naive()
benchmark_knn_cuda_versions()
benchmark_knn_cuda_versions_ragged()
def benchmark_knn_cuda_versions() -> None:
@@ -36,6 +37,25 @@ def benchmark_knn_cuda_versions() -> None:
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
def benchmark_knn_cuda_versions_ragged() -> None:
# Compare our different KNN implementations on ragged inputs
# against the same implementations on non-ragged inputs of the same shape
Ns = [8]
Ps = [4096, 16384]
Ds = [3]
Ks = [1, 4, 16, 64]
versions = [0, 1, 2, 3]
knn_kwargs = []
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
if version == 2 and K > 32:
continue
if version == 3 and K > 4:
continue
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
def benchmark_knn_cuda_vs_naive() -> None:
# Compare against naive pytorch version of KNN
Ns = [1, 2, 4]
@@ -72,10 +92,27 @@ def knn_cuda_with_init(N, D, P, K, v=-1):
device = torch.device("cuda:0")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
torch.cuda.synchronize()
def knn():
_C.knn_points_idx(x, y, K, v)
_C.knn_points_idx(x, y, lengths, lengths, K, v)
torch.cuda.synchronize()
return knn
def knn_cuda_ragged(N, D, P, K, v=-1):
device = torch.device("cuda:0")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
torch.cuda.synchronize()
def knn():
_C.knn_points_idx(x, y, lengths1, lengths2, K, v)
torch.cuda.synchronize()
return knn
@@ -85,9 +122,10 @@ def knn_cpu_with_init(N, D, P, K):
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
def knn():
_C.knn_points_idx(x, y, K, 0)
_C.knn_points_idx(x, y, lengths, lengths, K, -1)
return knn
@@ -96,10 +134,12 @@ def knn_python_cuda_with_init(N, D, P, K):
device = torch.device("cuda")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
torch.cuda.synchronize()
def knn():
_knn_points_idx_naive(x, y, K)
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
torch.cuda.synchronize()
return knn
@@ -109,9 +149,10 @@ def knn_python_cpu_with_init(N, D, P, K):
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
def knn():
_knn_points_idx_naive(x, y, K)
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
return knn

tests/test_knn.py

@@ -8,6 +8,10 @@ from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
class TestKNN(unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(1)
def _check_knn_result(self, out1, out2, sorted):
# When sorted=True, points should be sorted by distance and should
# match between implementations. When sorted=False we only want to
@@ -26,7 +30,7 @@ class TestKNN(unittest.TestCase):
self.assertTrue(torch.all(idx1 == idx2))
self.assertTrue(torch.allclose(dist1, dist2))
def test_knn_vs_python_cpu(self):
def test_knn_vs_python_cpu_square(self):
""" Test CPU output vs PyTorch implementation """
device = torch.device("cpu")
Ns = [1, 4]
@@ -37,13 +41,19 @@ class TestKNN(unittest.TestCase):
sorts = [True, False]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, K, sort)
out2 = knn_points_idx(x, y, K, sort)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
out2 = knn_points_idx(
x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda(self):
def test_knn_vs_python_cuda_square(self):
""" Test CUDA output vs PyTorch implementation """
device = torch.device("cuda")
Ns = [1, 4]
@@ -57,9 +67,53 @@ class TestKNN(unittest.TestCase):
for N, D, P1, P2, K, sort in product(*factors):
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, K, sorted=sort)
out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
for version in versions:
if version == 3 and K > 4:
continue
out2 = knn_points_idx(x, y, K, sort, version)
out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cpu_ragged(self):
device = torch.device("cpu")
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
N = 4
D = 3
Ks = [1, 9, 10, 11, 101]
sorts = [False, True]
factors = [Ks, sorts]
for K, sort in product(*factors):
x = torch.randn(N, lengths1.max(), D, device=device)
y = torch.randn(N, lengths2.max(), D, device=device)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
out2 = knn_points_idx(
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda_ragged(self):
device = torch.device("cuda")
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
N = 4
D = 3
Ks = [1, 9, 10, 11, 101]
sorts = [True, False]
versions = [0, 1, 2, 3]
factors = [Ks, sorts]
for K, sort in product(*factors):
x = torch.randn(N, lengths1.max(), D, device=device)
y = torch.randn(N, lengths2.max(), D, device=device)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
for version in versions:
if version == 3 and K > 4:
continue
out2 = knn_points_idx(
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort, version=version
)
self._check_knn_result(out1, out2, sort)