knn autograd

Summary:
Adds knn backward to return `grad_pts1` and `grad_pts2`. Adds `knn_gather` to return the nearest neighbors in pts2.

The benchmark tests include the backward pass and were run on an M40.
```
Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_SQUARE_32_256_128_3_24_cpu              39558           43485             13
KNN_SQUARE_32_256_128_3_24_cuda:0            1080            1404            463
KNN_SQUARE_32_256_512_3_24_cpu              81950           85781              7
KNN_SQUARE_32_256_512_3_24_cuda:0            1519            1641            330
--------------------------------------------------------------------------------

Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_RAGGED_32_256_128_3_24_cpu              13798           14650             37
KNN_RAGGED_32_256_128_3_24_cuda:0            1576            1713            318
KNN_RAGGED_32_256_512_3_24_cpu              31255           32210             16
KNN_RAGGED_32_256_512_3_24_cuda:0            2024            2162            248
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20945556

fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
This commit is contained in:
Georgia Gkioxari 2020-04-14 17:20:16 -07:00 committed by Facebook GitHub Bot
parent 487d4d6607
commit b2b0c5a442
8 changed files with 545 additions and 365 deletions

View File

@ -20,6 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("packed_to_padded", &PackedToPadded); m.def("packed_to_padded", &PackedToPadded);
m.def("padded_to_packed", &PaddedToPacked); m.def("padded_to_packed", &PaddedToPacked);
m.def("knn_points_idx", &KNearestNeighborIdx); m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
m.def("nn_points_idx", &NearestNeighborIdx); m.def("nn_points_idx", &NearestNeighborIdx);
m.def("gather_scatter", &gather_scatter); m.def("gather_scatter", &gather_scatter);
m.def("rasterize_points", &RasterizePoints); m.def("rasterize_points", &RasterizePoints);

View File

@ -412,3 +412,93 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
return std::make_tuple(idxs, dists); return std::make_tuple(idxs, dists);
} }
// ------------------------------------------------------------- //
// Backward Operators //
// ------------------------------------------------------------- //
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
// Backward kernel for K-nearest-neighbors on float data.
//
// The work is flattened over the index space (n, p1_idx, k, d) of size
// N * P1 * K * D. Each thread starts at its global thread id and advances by
// gridDim.x * blockDim.x until the whole index space has been visited.
//
// For every element with p1_idx < lengths1[n] and k < lengths2[n]:
//   p2_idx = idxs[n, p1_idx, k]
//   g = 2 * grad_dists[n, p1_idx, k] * (p1[n, p1_idx, d] - p2[n, p2_idx, d])
//   grad_p1[n, p1_idx, d] += g
//   grad_p2[n, p2_idx, d] -= g
// Both updates go through atomicAdd because several (k, p1_idx) pairs can
// target the same output element.
//
// All pointers are indexed as densely packed row-major buffers with the
// shapes noted on each parameter. grad_p1 and grad_p2 are only accumulated
// into, so the caller must provide them zero-initialized.
__global__ void KNearestNeighborBackwardKernel(
    const float* __restrict__ p1, // (N, P1, D)
    const float* __restrict__ p2, // (N, P2, D)
    const int64_t* __restrict__ lengths1, // (N,)
    const int64_t* __restrict__ lengths2, // (N,)
    const int64_t* __restrict__ idxs, // (N, P1, K)
    const float* __restrict__ grad_dists, // (N, P1, K)
    float* __restrict__ grad_p1, // (N, P1, D)
    float* __restrict__ grad_p2, // (N, P2, D)
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t K,
    const size_t D) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = gridDim.x * blockDim.x;
  const size_t total = N * P1 * K * D;

  for (size_t i = tid; i < total; i += stride) {
    const size_t n = i / (P1 * K * D); // batch index
    size_t rem = i % (P1 * K * D);
    const size_t p1_idx = rem / (K * D); // index of point in p1
    rem = rem % (K * D);
    const size_t k = rem / D; // k-th nearest neighbor
    const size_t d = rem % D; // d-th dimension in the feature vector

    const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
    const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
    if ((p1_idx < num1) && (k < num2)) {
      // grad_dists and idxs share the same (N, P1, K) layout.
      const size_t nbr_offset = n * P1 * K + p1_idx * K + k;
      const float grad_dist = grad_dists[nbr_offset];
      // index of point in p2 corresponding to the k-th nearest neighbor
      const size_t p2_idx = idxs[nbr_offset];

      const size_t p1_offset = n * P1 * D + p1_idx * D + d;
      const size_t p2_offset = n * P2 * D + p2_idx * D + d;

      // Float literal keeps the arithmetic in single precision (a bare 2.0
      // would promote the whole expression to double).
      const float diff = 2.0f * grad_dist * (p1[p1_offset] - p2[p2_offset]);
      atomicAdd(grad_p1 + p1_offset, diff);
      atomicAdd(grad_p2 + p2_offset, -1.0f * diff);
    }
  }
}
// CUDA implementation of the KNN backward pass (float tensors only).
//
// Args:
//    p1: (N, P1, D) float tensor.
//    p2: (N, P2, D) float tensor.
//    lengths1: (N,) int64 tensor, number of valid points per cloud in p1.
//    lengths2: (N,) int64 tensor, number of valid points per cloud in p2.
//    idxs: (N, P1, K) int64 tensor of neighbor indices into p2.
//    grad_dists: (N, P1, K) float tensor of incoming gradients.
//
// Returns:
//    (grad_p1, grad_p2) with shapes (N, P1, D) and (N, P2, D), created
//    zero-filled with the options of p1 and p2 respectively.
//
// The kernel indexes every buffer as a dense row-major array, so all inputs
// are made contiguous before their raw pointers are taken.
// NOTE(review): the launch passes no stream argument; confirm that is the
// intended stream for callers.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const at::Tensor& grad_dists) {
  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
  // D must come from p1 so that the comparison against p2 below is a real
  // check rather than a tautology.
  const auto D = p1.size(2);
  const auto K = idxs.size(2);

  AT_ASSERTM(p2.size(0) == N, "Point sets must have the same batch dimension");
  AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
  AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
  AT_ASSERTM(
      idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
  AT_ASSERTM(
      grad_dists.size(0) == N,
      "grad_dists must have the same batch dimension as p1");
  AT_ASSERTM(
      grad_dists.size(1) == P1,
      "grad_dists must have the same point dimension as p1");
  AT_ASSERTM(
      grad_dists.size(2) == K,
      "grad_dists must have the same K dimension as KNN idxs");

  // contiguous() returns the tensor itself when it is already contiguous.
  // The named locals keep any materialized copies alive across the launch.
  const at::Tensor p1_c = p1.contiguous();
  const at::Tensor p2_c = p2.contiguous();
  const at::Tensor lengths1_c = lengths1.contiguous();
  const at::Tensor lengths2_c = lengths2.contiguous();
  const at::Tensor idxs_c = idxs.contiguous();
  const at::Tensor grad_dists_c = grad_dists.contiguous();

  auto grad_p1 = at::zeros({N, P1, D}, p1.options());
  auto grad_p2 = at::zeros({N, P2, D}, p2.options());

  const int blocks = 64;
  const int threads = 512;

  KNearestNeighborBackwardKernel<<<blocks, threads>>>(
      p1_c.data_ptr<float>(),
      p2_c.data_ptr<float>(),
      lengths1_c.data_ptr<int64_t>(),
      lengths2_c.data_ptr<int64_t>(),
      idxs_c.data_ptr<int64_t>(),
      grad_dists_c.data_ptr<float>(),
      grad_p1.data_ptr<float>(),
      grad_p2.data_ptr<float>(),
      N,
      P1,
      P2,
      K,
      D);

  return std::make_tuple(grad_p1, grad_p2);
}

View File

@ -16,8 +16,6 @@
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. // lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. // lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// K: int giving the number of nearest points to return. // K: int giving the number of nearest points to return.
// sorted: bool telling whether to sort the K returned points by their
// distance.
// version: Integer telling which implementation to use. // version: Integer telling which implementation to use.
// //
// Returns: // Returns:
@ -67,3 +65,66 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
} }
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K); return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
} }
// Compute gradients with respect to p1 and p2
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation. This is computed from the forward pass.
// grad_dists: FloatTensor of shape (N, P1, K) which contains the input
// gradients.
//
// Returns:
// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
// wrt p1.
// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
// wrt p2.
// CPU implementation.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& p1, // (N, P1, D)
const at::Tensor& p2, // (N, P2, D)
const at::Tensor& lengths1, // (N,)
const at::Tensor& lengths2, // (N,)
const at::Tensor& idxs, // (N, P1, K); described as p1_neighbor_idx above
const at::Tensor& grad_dists); // (N, P1, K)
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& p1, // (N, P1, D)
const at::Tensor& p2, // (N, P2, D)
const at::Tensor& lengths1, // (N,)
const at::Tensor& lengths2, // (N,)
const at::Tensor& idxs, // (N, P1, K); described as p1_neighbor_idx above
const at::Tensor& grad_dists); // (N, P1, K)
// Implementation which is exposed.
// Entry point for the KNN backward pass.
//
// If either p1 or p2 lives on a CUDA device, the CUDA implementation is used
// (after checking p1 and p2 with CHECK_CONTIGUOUS_CUDA); when the extension
// was built without WITH_CUDA this raises an error instead. Otherwise the CPU
// implementation is used.
//
// Returns the tuple (grad_p1, grad_p2) produced by the selected backend.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const at::Tensor& grad_dists) {
  const bool any_on_gpu = p1.is_cuda() || p2.is_cuda();
  if (any_on_gpu) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(p1);
    CHECK_CONTIGUOUS_CUDA(p2);
    return KNearestNeighborBackwardCuda(
        p1, p2, lengths1, lengths2, idxs, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  // Neither point set is on the GPU: run on the CPU.
  return KNearestNeighborBackwardCpu(
      p1, p2, lengths1, lengths2, idxs, grad_dists);
}

View File

@ -57,3 +57,51 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
} }
return std::make_tuple(idxs, dists); return std::make_tuple(idxs, dists);
} }
// ------------------------------------------------------------- //
// Backward Operators //
// ------------------------------------------------------------- //
// CPU implementation of the KNN backward pass (float tensors only).
//
// Args:
//    p1: (N, P1, D) float tensor.
//    p2: (N, P2, D) float tensor.
//    lengths1: (N,) int64 tensor, number of valid points per cloud in p1.
//    lengths2: (N,) int64 tensor, number of valid points per cloud in p2.
//    idxs: (N, P1, K) int64 tensor of neighbor indices into p2.
//    grad_dists: (N, P1, K) float tensor of incoming gradients.
//
// Returns:
//    (grad_p1, grad_p2) with shapes (N, P1, D) and (N, P2, D). For each valid
//    (n, i1, k) pair with i2 = idxs[n][i1][k], the value
//      2 * grad_dists[n][i1][k] * (p1[n][i1][d] - p2[n][i2][d])
//    is added to grad_p1[n][i1][d] and subtracted from grad_p2[n][i2][d].
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const at::Tensor& grad_dists) {
  // Tensor sizes are int64_t; keep them at full width instead of narrowing.
  const int64_t N = p1.size(0);
  const int64_t P1 = p1.size(1);
  const int64_t D = p1.size(2);
  const int64_t P2 = p2.size(1);
  const int64_t K = idxs.size(2);

  // Same shape validation as the CUDA implementation, so that mismatched
  // inputs fail loudly instead of reading out of bounds through accessors.
  AT_ASSERTM(p2.size(0) == N, "Point sets must have the same batch dimension");
  AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
  AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
  AT_ASSERTM(
      idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
  AT_ASSERTM(
      grad_dists.size(0) == N,
      "grad_dists must have the same batch dimension as p1");
  AT_ASSERTM(
      grad_dists.size(1) == P1,
      "grad_dists must have the same point dimension as p1");
  AT_ASSERTM(
      grad_dists.size(2) == K,
      "grad_dists must have the same K dimension as KNN idxs");

  torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
  torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());

  auto p1_a = p1.accessor<float, 3>();
  auto p2_a = p2.accessor<float, 3>();
  auto lengths1_a = lengths1.accessor<int64_t, 1>();
  auto lengths2_a = lengths2.accessor<int64_t, 1>();
  auto idxs_a = idxs.accessor<int64_t, 3>();
  auto grad_dists_a = grad_dists.accessor<float, 3>();
  auto grad_p1_a = grad_p1.accessor<float, 3>();
  auto grad_p2_a = grad_p2.accessor<float, 3>();

  for (int64_t n = 0; n < N; ++n) {
    const int64_t length1 = lengths1_a[n];
    // Only the first min(lengths2[n], K) neighbor slots are visited.
    int64_t length2 = lengths2_a[n];
    length2 = (length2 < K) ? length2 : K;
    for (int64_t i1 = 0; i1 < length1; ++i1) {
      for (int64_t k = 0; k < length2; ++k) {
        const int64_t i2 = idxs_a[n][i1][k];
        // Loop-invariant over d: hoisted out of the innermost loop.
        const float scale = 2.0f * grad_dists_a[n][i1][k];
        for (int64_t d = 0; d < D; ++d) {
          const float diff = scale * (p1_a[n][i1][d] - p2_a[n][i2][d]);
          grad_p1_a[n][i1][d] += diff;
          grad_p2_a[n][i2][d] += -1.0f * diff;
        }
      }
    }
  }
  return std::make_tuple(grad_p1, grad_p2);
}

View File

@ -3,6 +3,7 @@
from .cubify import cubify from .cubify import cubify
from .graph_conv import GraphConv from .graph_conv import GraphConv
from .knn import knn_gather, knn_points
from .mesh_face_areas_normals import mesh_face_areas_normals from .mesh_face_areas_normals import mesh_face_areas_normals
from .nearest_neighbor_points import nn_points_idx from .nearest_neighbor_points import nn_points_idx
from .packed_to_padded import packed_to_padded, padded_to_packed from .packed_to_padded import packed_to_padded, padded_to_packed

View File

@ -1,152 +1,215 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from collections import namedtuple
from typing import Union
import torch import torch
from pytorch3d import _C from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable
def knn_points_idx( _KNN = namedtuple("KNN", "dists idx knn")
p1,
p2,
K: int, class _knn_points(Function):
lengths1=None, """
lengths2=None, Torch autograd Function wrapper for KNN C++/CUDA implementations.
sorted: bool = False, """
@staticmethod
def forward(ctx, p1, p2, lengths1, lengths2, K, version):
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
containing up to P2 points of dimension D.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
Returns:
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points.
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
"""
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
# sort KNN in ascending order if K > 1
if K > 1:
if lengths2.min() < K:
P1 = p1.shape[1]
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
# mask has shape [N, K], true where dists irrelevant
mask = mask[:, None].expand(-1, P1, -1)
# mask has shape [N, P1, K], true where dists irrelevant
dists[mask] = float("inf")
dists, sort_idx = dists.sort(dim=2)
dists[mask] = 0
else:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
return dists, idx
@staticmethod
@once_differentiable
def backward(ctx, grad_dists, grad_idx):
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
# TODO(gkioxari) Change cast to floats once we add support for doubles.
if not (grad_dists.dtype == torch.float32):
grad_dists = grad_dists.float()
if not (p1.dtype == torch.float32):
p1 = p1.float()
if not (p2.dtype == torch.float32):
p2 = p2.float()
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, grad_dists
)
return grad_p1, grad_p2, None, None, None, None
def knn_points(
p1: torch.Tensor,
p2: torch.Tensor,
lengths1: Union[torch.Tensor, None] = None,
lengths2: Union[torch.Tensor, None] = None,
K: int = 1,
version: int = -1, version: int = -1,
return_nn: bool = False,
): ):
""" """
K-Nearest neighbors on point clouds. K-Nearest neighbors on point clouds.
Args: Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
containing up to P1 points of dimension D. containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
containing up to P2 points of dimension D. containing up to P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has length of each pointcloud in p1. Or None to indicate that every cloud has
length P1. length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has length of each pointcloud in p2. Or None to indicate that every cloud has
length P2. length P2.
sorted: Whether to sort the resulting points. K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1, version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs. the correct implementation is selected based on the shapes of the inputs.
return_nn: If set to True, returns the K nearest neighbors in p2 for each point in p1.
Returns: Returns:
p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2. K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 in p2 has fewer than K points and where a cloud in p1 has fewer than P1
points. points.
If you want an (N, P1, K, D) tensor of the actual points, you can get it
using
p2[:, :, None].expand(-1, -1, K, -1).gather(1,
x_idx[:, :, :, None].expand(-1, -1, -1, D)
)
If K=1 and you want an (N, P1, D) tensor of the actual points, use
p2.gather(1, x_idx.expand(-1, -1, D))
p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2 the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points. has fewer than K points and where a cloud in p1 has fewer than P1 points.
Warning: this is calculated outside of the autograd framework.
p2_nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor
for `p1[n, i]`. Returned if `return_nn` is True.
The nearest neighbors are collected using `knn_gather`
.. code-block::
p2_nn = knn_gather(p2, p1_idx, lengths2)
which is a helper function that allows indexing any tensor of shape (N, P2, U) with
the indices `p1_idx` returned by `knn_points`. The output is a tensor
of shape (N, P1, K, U).
""" """
if p1.shape[0] != p2.shape[0]:
raise ValueError("pts1 and pts2 must have the same batch dimension.")
if p1.shape[2] != p2.shape[2]:
raise ValueError("pts1 and pts2 must have the same point dimension.")
P1 = p1.shape[1] P1 = p1.shape[1]
P2 = p2.shape[1] P2 = p2.shape[1]
if lengths1 is None: if lengths1 is None:
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device) lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
if lengths2 is None: if lengths2 is None:
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
if sorted: p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version)
if lengths2.min() < K:
device = dists.device p2_nn = None
mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None] if return_nn:
# mask1 has shape [N, K], true where dists irrelevant p2_nn = knn_gather(p2, p1_idx, lengths2)
mask2 = mask1[:, None].expand(-1, P1, -1)
# mask2 has shape [N, P1, K], true where dists irrelevant return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None)
dists[mask2] = float("inf")
dists, sort_idx = dists.sort(dim=2)
dists[mask2] = 0
else:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
return idx, dists
@torch.no_grad() def knn_gather(
def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2) -> torch.Tensor: x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None
):
""" """
Naive PyTorch implementation of K-Nearest Neighbors. A helper function for knn that allows indexing a tensor x with the indices `idx`
returned by `knn_points`.
This is much less efficient than _C.knn_points_idx, but we include this For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)`
naive implementation for testing and benchmarking. where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D),
then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`.
It can also be applied for any tensor x of shape (N, M, U) where U != D.
Args: Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each x: Tensor of shape (N, M, U) containing U-dimensional features to
containing up to P1 points of dimension D. be gathered.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`.
containing up to P2 points of dimension D. lengths: LongTensor of shape (N,) of values in the range [0, M], giving the
K: Integer giving the number of nearest neighbors to return. length of each example in the batch in x. Or None to indicate that every
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the example has length M.
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
Returns: Returns:
idx: LongTensor of shape (N, P1, K) giving the indices of the x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x
K nearest neighbors from points in p1 to points in p2. with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`.
Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0.
to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer
than K points and where a cloud in p1 has fewer than P1 points.
dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest
neighbors. This is padded with zeros both where a cloud in p2 has fewer than
K points and where a cloud in p1 has fewer than P1 points.
""" """
N, P1, D = p1.shape N, M, U = x.shape
_N, P2, _D = p2.shape _N, L, K = idx.shape
assert N == _N and D == _D if N != _N:
raise ValueError("x and idx must have same batch dimension.")
if lengths1 is None: if lengths is None:
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device) lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device)
if lengths2 is None:
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
p1_copy = p1.clone() idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U)
p2_copy = p2.clone() # idx_expanded has shape [N, L, K, U]
# We pad the values with infinities so that the smallest differences are x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)
# among actual points. # p2_nn has shape [N, L, K, U]
inf = float("inf")
p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None]
p1_copy[p1_mask] = inf
p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf
# view is safe here: we are merely adding extra dimensions of length 1 needs_mask = lengths.min() < K
diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D) if needs_mask:
dists2 = (diffs * diffs).sum(dim=3) # mask has shape [N, K], true where idx is irrelevant because
# there is less number of points in p2 than K
mask = lengths[:, None] <= torch.arange(K, device=x.device)[None]
# We always sort, because this works well with padding. # expand mask to shape [N, L, K, U]
out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True) mask = mask[:, None].expand(-1, L, -1)
mask = mask[:, :, :, None].expand(-1, -1, -1, U)
x_out[mask] = 0.0
out_indices = out.indices return x_out
out_values = out.values
if P2 < K:
# Need to add padding
pad_shape = (N, P1, K - P2)
out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2)
out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2)
K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None]
# Create a combined mask for where the points in p1 are padded
# or the corresponding p2 has fewer than K points.
p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :]
out_indices[p1_K_mask] = 0
out_values[p1_K_mask] = 0
return out_indices, out_values

View File

@ -2,180 +2,25 @@
from itertools import product from itertools import product
import torch
from fvcore.common.benchmark import benchmark from fvcore.common.benchmark import benchmark
from pytorch3d import _C from test_knn import TestKNN
from pytorch3d.ops.knn import _knn_points_idx_naive
def bm_knn() -> None: def bm_knn() -> None:
""" Entry point for the benchmark """
benchmark_knn_cpu()
benchmark_knn_cuda_vs_naive()
benchmark_knn_cuda_versions()
benchmark_knn_cuda_versions_ragged()
backends = ["cpu", "cuda:0"]
def benchmark_knn_cuda_versions() -> None: kwargs_list = []
# Compare our different KNN implementations, Ns = [32]
# and also compare against our existing 1-NN P1s = [256]
Ns = [1, 2] P2s = [128, 512]
Ps = [4096, 16384]
Ds = [3] Ds = [3]
Ks = [1, 4, 16, 64] Ks = [24]
versions = [0, 1, 2, 3] test_cases = product(Ns, P1s, P2s, Ds, Ks, backends)
knn_kwargs, nn_kwargs = [], [] for case in test_cases:
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions): N, P1, P2, D, K, b = case
if version == 2 and K > 32: kwargs_list.append({"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "device": b})
continue
if version == 3 and K > 4:
continue
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({"N": N, "D": D, "P": P})
benchmark(knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1)
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1)
def benchmark_knn_cuda_versions_ragged() -> None: benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1)
# Compare our different KNN implementations,
# and also compare against our existing 1-NN
Ns = [8]
Ps = [4096, 16384]
Ds = [3]
Ks = [1, 4, 16, 64]
versions = [0, 1, 2, 3]
knn_kwargs = []
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
if version == 2 and K > 32:
continue
if version == 3 and K > 4:
continue
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
def benchmark_knn_cuda_vs_naive() -> None:
# Compare against naive pytorch version of KNN
Ns = [1, 2, 4]
Ps = [1024, 4096, 16384, 65536]
Ds = [3]
Ks = [1, 2, 4, 8, 16]
knn_kwargs, naive_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
if P <= 4096:
naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
benchmark(
knn_python_cuda_with_init, "KNN_CUDA_PYTHON", naive_kwargs, warmup_iters=1
)
benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)
def benchmark_knn_cpu() -> None:
Ns = [1, 2]
Ps = [256, 512]
Ds = [3]
Ks = [1, 2, 4]
knn_kwargs, nn_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({"N": N, "D": D, "P": P})
benchmark(knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1)
benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)
def knn_cuda_with_init(N, D, P, K, v=-1):
device = torch.device("cuda:0")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
torch.cuda.synchronize()
def knn():
_C.knn_points_idx(x, y, lengths, lengths, K, v)
torch.cuda.synchronize()
return knn
def knn_cuda_ragged(N, D, P, K, v=-1):
device = torch.device("cuda:0")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
torch.cuda.synchronize()
def knn():
_C.knn_points_idx(x, y, lengths1, lengths2, K, v)
torch.cuda.synchronize()
return knn
def knn_cpu_with_init(N, D, P, K):
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
def knn():
_C.knn_points_idx(x, y, lengths, lengths, K, -1)
return knn
def knn_python_cuda_with_init(N, D, P, K):
device = torch.device("cuda")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
torch.cuda.synchronize()
def knn():
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
torch.cuda.synchronize()
return knn
def knn_python_cpu_with_init(N, D, P, K):
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
def knn():
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
return knn
def nn_cuda_with_init(N, D, P):
device = torch.device("cuda")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
def knn():
_C.nn_points_idx(x, y)
torch.cuda.synchronize()
return knn
def nn_cpu_with_init(N, D, P):
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
def knn():
_C.nn_points_idx(x, y)
return knn

View File

@ -4,116 +4,187 @@ import unittest
from itertools import product from itertools import product
import torch import torch
from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx from common_testing import TestCaseMixin
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
class TestKNN(unittest.TestCase): class TestKNN(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
torch.manual_seed(1) torch.manual_seed(1)
def _check_knn_result(self, out1, out2, sorted): @staticmethod
# When sorted=True, points should be sorted by distance and should def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
# match between implementations. When sorted=False we we only want to """
# check that we got the same set of indices, so we sort the indices by Naive PyTorch implementation of K-Nearest Neighbors.
# index value. Returns always sorted results
idx1, dist1 = out1 """
idx2, dist2 = out2 N, P1, D = p1.shape
if not sorted: _N, P2, _D = p2.shape
idx1 = idx1.sort(dim=2).values
idx2 = idx2.sort(dim=2).values
dist1 = dist1.sort(dim=2).values
dist2 = dist2.sort(dim=2).values
if not torch.all(idx1 == idx2):
print(idx1)
print(idx2)
self.assertTrue(torch.all(idx1 == idx2))
self.assertTrue(torch.allclose(dist1, dist2))
def test_knn_vs_python_cpu_square(self): assert N == _N and D == _D
""" Test CPU output vs PyTorch implementation """
device = torch.device("cpu")
Ns = [1, 4]
Ds = [2, 3]
P1s = [1, 10, 101]
P2s = [10, 101]
Ks = [1, 3, 10]
sorts = [True, False]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
out2 = knn_points_idx(
x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda_square(self): if lengths1 is None:
""" Test CUDA output vs PyTorch implementation """ lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
device = torch.device("cuda") if lengths2 is None:
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
idx = torch.zeros((N, P1, K), dtype=torch.int64, device=p1.device)
for n in range(N):
num1 = lengths1[n].item()
num2 = lengths2[n].item()
pp1 = p1[n, :num1].view(num1, 1, D)
pp2 = p2[n, :num2].view(1, num2, D)
diff = pp1 - pp2
diff = (diff * diff).sum(2)
num2 = min(num2, K)
for i in range(num1):
dd = diff[i]
srt_dd, srt_idx = dd.sort()
dists[n, i, :num2] = srt_dd[:num2]
idx[n, i, :num2] = srt_idx[:num2]
return _KNN(dists=dists, idx=idx, knn=None)
def _knn_vs_python_square_helper(self, device):
Ns = [1, 4] Ns = [1, 4]
Ds = [2, 3, 8] Ds = [3, 5, 8]
P1s = [1, 8, 64, 128, 1001] P1s = [8, 24]
P2s = [32, 128, 513] P2s = [8, 16, 32]
Ks = [1, 3, 10] Ks = [1, 3, 10]
sorts = [True, False]
versions = [0, 1, 2, 3] versions = [0, 1, 2, 3]
factors = [Ns, Ds, P1s, P2s, Ks, sorts] factors = [Ns, Ds, P1s, P2s, Ks]
for N, D, P1, P2, K, sort in product(*factors): for N, D, P1, P2, K in product(*factors):
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
for version in versions: for version in versions:
if version == 3 and K > 4: if version == 3 and K > 4:
continue continue
out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version) x = torch.randn(N, P1, D, device=device, requires_grad=True)
self._check_knn_result(out1, out2, sort) x_cuda = x.clone().detach()
x_cuda.requires_grad_(True)
y = torch.randn(N, P2, D, device=device, requires_grad=True)
y_cuda = y.clone().detach()
y_cuda.requires_grad_(True)
def test_knn_vs_python_cpu_ragged(self): # forward
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
self.assertClose(out1[0], out2[0])
self.assertTrue(torch.all(out1[1] == out2[1]))
# backward
grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
loss1 = (out1.dists * grad_dist).sum()
loss1.backward()
loss2 = (out2.dists * grad_dist).sum()
loss2.backward()
self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
def test_knn_vs_python_square_cpu(self):
device = torch.device("cpu") device = torch.device("cpu")
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64) self._knn_vs_python_square_helper(device)
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
N = 4
D = 3
Ks = [1, 9, 10, 11, 101]
sorts = [False, True]
factors = [Ks, sorts]
for K, sort in product(*factors):
x = torch.randn(N, lengths1.max(), D, device=device)
y = torch.randn(N, lengths2.max(), D, device=device)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
out2 = knn_points_idx(
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda_ragged(self): def test_knn_vs_python_square_cuda(self):
device = torch.device("cuda") device = torch.device("cuda:0")
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64) self._knn_vs_python_square_helper(device)
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
N = 4 def _knn_vs_python_ragged_helper(self, device):
D = 3 Ns = [1, 4]
Ks = [1, 9, 10, 11, 101] Ds = [3, 5, 8]
sorts = [True, False] P1s = [8, 24]
versions = [0, 1, 2, 3] P2s = [8, 16, 32]
factors = [Ks, sorts] Ks = [1, 3, 10]
for K, sort in product(*factors): factors = [Ns, Ds, P1s, P2s, Ks]
x = torch.randn(N, lengths1.max(), D, device=device) for N, D, P1, P2, K in product(*factors):
y = torch.randn(N, lengths2.max(), D, device=device) x = torch.rand((N, P1, D), device=device, requires_grad=True)
out1 = _knn_points_idx_naive( y = torch.rand((N, P2, D), device=device, requires_grad=True)
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
x_csrc = x.clone().detach()
x_csrc.requires_grad_(True)
y_csrc = y.clone().detach()
y_csrc.requires_grad_(True)
# forward
out1 = self._knn_points_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K x, y, lengths1=lengths1, lengths2=lengths2, K=K
) )
for version in versions: out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
if version == 3 and K > 4: self.assertClose(out1[0], out2[0])
continue self.assertTrue(torch.all(out1[1] == out2[1]))
out2 = knn_points_idx(
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort # backward
) grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
self._check_knn_result(out1, out2, sort) loss1 = (out1.dists * grad_dist).sum()
loss1.backward()
loss2 = (out2.dists * grad_dist).sum()
loss2.backward()
self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
self.assertClose(y_csrc.grad, y.grad, atol=5e-6)
def test_knn_vs_python_ragged_cpu(self):
device = torch.device("cpu")
self._knn_vs_python_ragged_helper(device)
def test_knn_vs_python_ragged_cuda(self):
device = torch.device("cuda:0")
self._knn_vs_python_ragged_helper(device)
def test_knn_gather(self):
    """
    Compare `knn_gather` output with direct indexing into `y`.

    For every (batch, query point, neighbor slot) triple: slots below
    `lengths2[n]` must equal the point of `y` selected by the index that
    `knn_points` returned; the remaining slots must be all zeros.
    """
    device = torch.device("cuda:0")
    N, P1, P2, K, D = 4, 16, 12, 8, 3
    x = torch.rand((N, P1, D), device=device)
    y = torch.rand((N, P2, D), device=device)
    lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
    lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

    knn_out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K)
    gathered = knn_gather(y, knn_out.idx, lengths2)

    for n, i, k in product(range(N), range(P1), range(K)):
        neighbor = gathered[n, i, k]
        if k >= lengths2[n]:
            # Slot beyond the number of valid points in y[n].
            self.assertTrue(torch.all(neighbor == 0.0))
        else:
            self.assertClose(neighbor, y[n, knn_out.idx[n, i, k]])
@staticmethod
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
    """
    Build a benchmark closure for `knn_points` on full (non-ragged) inputs.

    Args:
        N: batch size.
        P1: number of points per batch element in the first cloud.
        P2: number of points per batch element in the second cloud.
        D: dimension of each point.
        K: number of neighbors requested.
        device: device string understood by `torch.device`.

    Returns:
        A zero-argument callable that runs `knn_points` forward and then
        backpropagates a weighted sum of the returned distances.
    """
    device = torch.device(device)
    # Only synchronize when actually benchmarking on a GPU; calling
    # torch.cuda.synchronize() on a CUDA-less host raises.
    on_cuda = device.type == "cuda"
    pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
    pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
    grad_dists = torch.randn(N, P1, K, device=device)
    if on_cuda:
        torch.cuda.synchronize()

    def output():
        out = knn_points(pts1, pts2, K=K)
        loss = (out.dists * grad_dists).sum()
        loss.backward()
        if on_cuda:
            # Wait for queued kernels so the timing covers the real work.
            torch.cuda.synchronize()

    return output
@staticmethod
def knn_ragged(N: int, P1: int, P2: int, D: int, K: int, device: str):
    """
    Build a benchmark closure for `knn_points` on ragged inputs.

    Per-batch lengths are drawn with `torch.randint(low=1, high=P)`, so
    each length lies in [1, P - 1] (the upper bound is exclusive).

    Args:
        N: batch size.
        P1: padded number of points in the first cloud.
        P2: padded number of points in the second cloud.
        D: dimension of each point.
        K: number of neighbors requested.
        device: device string understood by `torch.device`.

    Returns:
        A zero-argument callable that runs `knn_points` forward with the
        sampled lengths and then backpropagates a weighted sum of the
        returned distances.
    """
    device = torch.device(device)
    # Only synchronize when actually benchmarking on a GPU; calling
    # torch.cuda.synchronize() on a CUDA-less host raises.
    on_cuda = device.type == "cuda"
    pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
    pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
    lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
    lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
    grad_dists = torch.randn(N, P1, K, device=device)
    if on_cuda:
        torch.cuda.synchronize()

    def output():
        out = knn_points(pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K)
        loss = (out.dists * grad_dists).sum()
        loss.backward()
        if on_cuda:
            # Wait for queued kernels so the timing covers the real work.
            torch.cuda.synchronize()

    return output