diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index c0141f28..b162bc59 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -20,6 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("packed_to_padded", &PackedToPadded); m.def("padded_to_packed", &PaddedToPacked); m.def("knn_points_idx", &KNearestNeighborIdx); + m.def("knn_points_backward", &KNearestNeighborBackward); m.def("nn_points_idx", &NearestNeighborIdx); m.def("gather_scatter", &gather_scatter); m.def("rasterize_points", &RasterizePoints); diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu index b9580e5f..20250d7d 100644 --- a/pytorch3d/csrc/knn/knn.cu +++ b/pytorch3d/csrc/knn/knn.cu @@ -412,3 +412,93 @@ std::tuple KNearestNeighborIdxCuda( return std::make_tuple(idxs, dists); } + +// ------------------------------------------------------------- // +// Backward Operators // +// ------------------------------------------------------------- // + +// TODO(gkioxari) support all data types once AtomicAdd supports doubles. +// Currently, support is for floats only. 
+__global__ void KNearestNeighborBackwardKernel(
+    const float* __restrict__ p1, // (N, P1, D)
+    const float* __restrict__ p2, // (N, P2, D)
+    const int64_t* __restrict__ lengths1, // (N,)
+    const int64_t* __restrict__ lengths2, // (N,)
+    const int64_t* __restrict__ idxs, // (N, P1, K)
+    const float* __restrict__ grad_dists, // (N, P1, K)
+    float* __restrict__ grad_p1, // (N, P1, D), accumulated into
+    float* __restrict__ grad_p2, // (N, P2, D), accumulated into
+    const size_t N,
+    const size_t P1,
+    const size_t P2,
+    const size_t K,
+    const size_t D) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = gridDim.x * blockDim.x;
+
+  for (size_t i = tid; i < N * P1 * K * D; i += stride) { // one (n, p1, k, d)
+    const size_t n = i / (P1 * K * D); // batch index
+    size_t rem = i % (P1 * K * D);
+    const size_t p1_idx = rem / (K * D); // index of point in p1
+    rem = rem % (K * D);
+    const size_t k = rem / D; // k-th nearest neighbor
+    const size_t d = rem % D; // d-th dimension in the feature vector
+
+    const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
+    const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
+    if ((p1_idx < num1) && (k < num2)) { // skip padded points / neighbors
+      const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
+      // index of point in p2 corresponding to the k-th nearest neighbor
+      const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
+      const float diff = 2.0f * grad_dist * // d/dp1 of squared distance
+          (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff); // shared over k
+      atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
+    }
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    const at::Tensor& idxs,
+    const at::Tensor& grad_dists) {
+  const auto N = p1.size(0);
+  const auto P1 = p1.size(1);
+  const auto P2 = p2.size(1);
+  const auto D = p1.size(2); // from p1, so the p2 check below is not a no-op
+  const auto K = idxs.size(2);
+
+  AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
+  AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
+  AT_ASSERTM(
+      idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
+  AT_ASSERTM(grad_dists.size(0) == N);
+  AT_ASSERTM(grad_dists.size(1) == P1);
+  AT_ASSERTM(grad_dists.size(2) == K);
+
+  auto grad_p1 = at::zeros({N, P1, D}, p1.options());
+  auto grad_p2 = at::zeros({N, P2, D}, p2.options());
+
+  const int blocks = 64;
+  const int threads = 512;
+
+  KNearestNeighborBackwardKernel<<<blocks, threads>>>(
+      p1.data_ptr<float>(),
+      p2.data_ptr<float>(),
+      lengths1.data_ptr<int64_t>(),
+      lengths2.data_ptr<int64_t>(),
+      idxs.data_ptr<int64_t>(),
+      grad_dists.data_ptr<float>(),
+      grad_p1.data_ptr<float>(),
+      grad_p2.data_ptr<float>(),
+      N,
+      P1,
+      P2,
+      K,
+      D);
+
+  return std::make_tuple(grad_p1, grad_p2);
+}
diff --git a/pytorch3d/csrc/knn/knn.h b/pytorch3d/csrc/knn/knn.h index b447dfe2..321da6cd 100644 --- a/pytorch3d/csrc/knn/knn.h +++ b/pytorch3d/csrc/knn/knn.h @@ -16,8 +16,6 @@ // lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. // lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. // K: int giving the number of nearest points to return. -// sorted: bool telling whether to sort the K returned points by their -// distance. // version: Integer telling which implementation to use. // // Returns: @@ -67,3 +65,66 @@ std::tuple KNearestNeighborIdx( } return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K); } + +// Compute gradients with respect to p1 and p2 +// +// Args: +// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each +// containing P1 points of dimension D. +// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each +// containing P2 points of dimension D. +// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. +// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+//    idxs: LongTensor of shape (N, P1, K), where
+//        idxs[n, i, k] = j means that the kth nearest
+//        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+//        It is padded with zeros so that it can be used easily in a later
+//        gather() operation. This is computed from the forward pass.
+//    grad_dists: FloatTensor of shape (N, P1, K) which contains the input
+//        gradients.
+//
+// Returns:
+//    grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
+//        wrt p1.
+//    grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
+//        wrt p2.
+
+// CPU implementation.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    const at::Tensor& idxs,
+    const at::Tensor& grad_dists);
+
+// CUDA implementation
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
+    const at::Tensor& p1,
+    const at::Tensor& p2,
+    const at::Tensor& lengths1,
+    const at::Tensor& lengths2,
+    const at::Tensor& idxs,
+    const at::Tensor& grad_dists);
+
+// Implementation which is exposed.
+std::tuple KNearestNeighborBackward( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const at::Tensor& grad_dists) { + if (p1.is_cuda() || p2.is_cuda()) { +#ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(p1); + CHECK_CONTIGUOUS_CUDA(p2); + return KNearestNeighborBackwardCuda( + p1, p2, lengths1, lengths2, idxs, grad_dists); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return KNearestNeighborBackwardCpu( + p1, p2, lengths1, lengths2, idxs, grad_dists); +} diff --git a/pytorch3d/csrc/knn/knn_cpu.cpp b/pytorch3d/csrc/knn/knn_cpu.cpp index 84d18a65..0150a11d 100644 --- a/pytorch3d/csrc/knn/knn_cpu.cpp +++ b/pytorch3d/csrc/knn/knn_cpu.cpp @@ -57,3 +57,51 @@ std::tuple KNearestNeighborIdxCpu( } return std::make_tuple(idxs, dists); } + +// ------------------------------------------------------------- // +// Backward Operators // +// ------------------------------------------------------------- // + +std::tuple KNearestNeighborBackwardCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const at::Tensor& grad_dists) { + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + const int P2 = p2.size(1); + const int K = idxs.size(2); + + torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options()); + torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options()); + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto lengths1_a = lengths1.accessor(); + auto lengths2_a = lengths2.accessor(); + auto idxs_a = idxs.accessor(); + auto grad_dists_a = grad_dists.accessor(); + auto grad_p1_a = grad_p1.accessor(); + auto grad_p2_a = grad_p2.accessor(); + + for (int n = 0; n < N; ++n) { + const int64_t length1 = lengths1_a[n]; + int64_t length2 = lengths2_a[n]; + length2 = (length2 < K) ? 
length2 : K; + for (int64_t i1 = 0; i1 < length1; ++i1) { + for (int64_t k = 0; k < length2; ++k) { + const int64_t i2 = idxs_a[n][i1][k]; + for (int64_t d = 0; d < D; ++d) { + const float diff = + 2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]); + grad_p1_a[n][i1][d] += diff; + grad_p2_a[n][i2][d] += -1.0f * diff; + } + } + } + } + return std::make_tuple(grad_p1, grad_p2); +} diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 7d84f6d6..0ca9eb6d 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -3,6 +3,7 @@ from .cubify import cubify from .graph_conv import GraphConv +from .knn import knn_gather, knn_points from .mesh_face_areas_normals import mesh_face_areas_normals from .nearest_neighbor_points import nn_points_idx from .packed_to_padded import packed_to_padded, padded_to_packed diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py index 122336df..1043592b 100644 --- a/pytorch3d/ops/knn.py +++ b/pytorch3d/ops/knn.py @@ -1,152 +1,215 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from collections import namedtuple +from typing import Union + import torch from pytorch3d import _C +from torch.autograd import Function +from torch.autograd.function import once_differentiable -def knn_points_idx( - p1, - p2, - K: int, - lengths1=None, - lengths2=None, - sorted: bool = False, +_KNN = namedtuple("KNN", "dists idx knn") + + +class _knn_points(Function): + """ + Torch autograd Function wrapper for KNN C++/CUDA implementations. + """ + + @staticmethod + def forward(ctx, p1, p2, lengths1, lengths2, K, version): + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each + containing up to P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each + containing up to P2 points of dimension D. 
+ lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + K: Integer giving the number of nearest neighbors to return. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of the inputs. + + Returns: + p1_dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + p1_idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud + in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. 
+ """ + + idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version) + + # sort KNN in ascending order if K > 1 + if K > 1: + if lengths2.min() < K: + P1 = p1.shape[1] + mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None] + # mask has shape [N, K], true where dists irrelevant + mask = mask[:, None].expand(-1, P1, -1) + # mask has shape [N, P1, K], true where dists irrelevant + dists[mask] = float("inf") + dists, sort_idx = dists.sort(dim=2) + dists[mask] = 0 + else: + dists, sort_idx = dists.sort(dim=2) + idx = idx.gather(2, sort_idx) + + ctx.save_for_backward(p1, p2, lengths1, lengths2, idx) + return dists, idx + + @staticmethod + @once_differentiable + def backward(ctx, grad_dists, grad_idx): + p1, p2, lengths1, lengths2, idx = ctx.saved_tensors + # TODO(gkioxari) Change cast to floats once we add support for doubles. + if not (grad_dists.dtype == torch.float32): + grad_dists = grad_dists.float() + if not (p1.dtype == torch.float32): + p1 = p1.float() + if not (p2.dtype == torch.float32): + p2 = p2.float() + grad_p1, grad_p2 = _C.knn_points_backward( + p1, p2, lengths1, lengths2, idx, grad_dists + ) + return grad_p1, grad_p2, None, None, None, None + + +def knn_points( + p1: torch.Tensor, + p2: torch.Tensor, + lengths1: Union[torch.Tensor, None] = None, + lengths2: Union[torch.Tensor, None] = None, + K: int = 1, version: int = -1, + return_nn: bool = False, ): """ K-Nearest neighbors on point clouds. Args: - p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each containing up to P1 points of dimension D. - p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each containing up to P2 points of dimension D. - K: Integer giving the number of nearest neighbors to return. 
 lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the length of each pointcloud in p1. Or None to indicate that every cloud has length P1. lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the length of each pointcloud in p2. Or None to indicate that every cloud has length P2. - sorted: Whether to sort the resulting points. + K: Integer giving the number of nearest neighbors to return. version: Which KNN implementation to use in the backend. If version=-1, the correct implementation is selected based on the shapes of the inputs. + return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1. Returns: - p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the + p1_idx: LongTensor of shape (N, P1, K) giving the indices of the K nearest neighbors from points in p1 to points in p2. - Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest - neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth - nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbor to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. - If you want an (N, P1, K, D) tensor of the actual points, you can get it - using - p2[:, :, None].expand(-1, -1, K, -1).gather(1, - x_idx[:, :, :, None].expand(-1, -1, -1, D) - ) - If K=1 and you want an (N, P1, D) tensor of the actual points, use - p2.gather(1, x_idx.expand(-1, -1, D)) - p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to + p1_dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest neighbors. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. - Warning: this is calculated outside of the autograd framework. 
+ + p2_nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for + each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor + for `p1[n, i]`. Returned if `return_nn` is True. + The nearest neighbors are collected using `knn_gather` + + .. code-block:: + + p2_nn = knn_gather(p2, p1_idx, lengths2) + + which is a helper function that allows indexing any tensor of shape (N, P2, U) with + the indices `p1_idx` returned by `knn_points`. The output is a tensor + of shape (N, P1, K, U). + """ + if p1.shape[0] != p2.shape[0]: + raise ValueError("pts1 and pts2 must have the same batch dimension.") + if p1.shape[2] != p2.shape[2]: + raise ValueError("pts1 and pts2 must have the same point dimension.") + P1 = p1.shape[1] P2 = p2.shape[1] + if lengths1 is None: lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device) if lengths2 is None: lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) - idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version) - if sorted: - if lengths2.min() < K: - device = dists.device - mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None] - # mask1 has shape [N, K], true where dists irrelevant - mask2 = mask1[:, None].expand(-1, P1, -1) - # mask2 has shape [N, P1, K], true where dists irrelevant - dists[mask2] = float("inf") - dists, sort_idx = dists.sort(dim=2) - dists[mask2] = 0 - else: - dists, sort_idx = dists.sort(dim=2) - idx = idx.gather(2, sort_idx) - return idx, dists + + p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version) + + p2_nn = None + if return_nn: + p2_nn = knn_gather(p2, p1_idx, lengths2) + + return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None) -@torch.no_grad() -def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2) -> torch.Tensor: +def knn_gather( + x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None +): """ - Naive PyTorch implementation of 
K-Nearest Neighbors. - - This is much less efficient than _C.knn_points_idx, but we include this - naive implementation for testing and benchmarking. + A helper function for knn that allows indexing a tensor x with the indices `idx` + returned by `knn_points`. + + For example, if `dists, idx, _ = knn_points(p, x, lengths_p, lengths, K)` + where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D), + then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`. + It can also be applied for any tensor x of shape (N, M, U) where U != D. Args: - p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each - containing up to P1 points of dimension D. - p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each - containing up to P2 points of dimension D. - K: Integer giving the number of nearest neighbors to return. - lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the - length of each pointcloud in p1. Or None to indicate that every cloud has - length P1. - lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the - length of each pointcloud in p2. Or None to indicate that every cloud has - length P2. - + x: Tensor of shape (N, M, U) containing U-dimensional features to + be gathered. + idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`. + lengths: LongTensor of shape (N,) of values in the range [0, M], giving the + length of each example in the batch in x. Or None to indicate that every + example has length M. Returns: - idx: LongTensor of shape (N, P1, K) giving the indices of the - K nearest neighbors from points in p1 to points in p2. - Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor - to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer - than K points and where a cloud in p1 has fewer than P1 points. 
- dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest - neighbors. This is padded with zeros both where a cloud in p2 has fewer than - K points and where a cloud in p1 has fewer than P1 points. + x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x + with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`. + If `k >= lengths[n]` then `x_out[n, l, k]` is filled with 0.0. """ - N, P1, D = p1.shape - _N, P2, _D = p2.shape + N, M, U = x.shape + _N, L, K = idx.shape - assert N == _N and D == _D + if N != _N: + raise ValueError("x and idx must have same batch dimension.") - if lengths1 is None: - lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device) - if lengths2 is None: - lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) + if lengths is None: + lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device) - p1_copy = p1.clone() - p2_copy = p2.clone() + idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U) + # idx_expanded has shape [N, L, K, U] - # We pad the values with infinities so that the smallest differences are - # among actual points. - inf = float("inf") - p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None] - p1_copy[p1_mask] = inf - p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf + x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded) + # x_out has shape [N, L, K, U] - # view is safe here: we are merely adding extra dimensions of length 1 - diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D) - dists2 = (diffs * diffs).sum(dim=3) + needs_mask = lengths.min() < K + if needs_mask: + # mask has shape [N, K], true where idx is irrelevant because + # there are fewer points in x than K + mask = lengths[:, None] <= torch.arange(K, device=x.device)[None] - # We always sort, because this works well with padding. 
- out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True) + # expand mask to shape [N, L, K, U] + mask = mask[:, None].expand(-1, L, -1) + mask = mask[:, :, :, None].expand(-1, -1, -1, U) + x_out[mask] = 0.0 - out_indices = out.indices - out_values = out.values - - if P2 < K: - # Need to add padding - pad_shape = (N, P1, K - P2) - out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2) - out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2) - - K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None] - # Create a combined mask for where the points in p1 are padded - # or the corresponding p2 has fewer than K points. - p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :] - out_indices[p1_K_mask] = 0 - out_values[p1_K_mask] = 0 - return out_indices, out_values + return x_out diff --git a/tests/bm_knn.py b/tests/bm_knn.py index 38685811..4a96d64b 100644 --- a/tests/bm_knn.py +++ b/tests/bm_knn.py @@ -2,180 +2,25 @@ from itertools import product -import torch from fvcore.common.benchmark import benchmark -from pytorch3d import _C -from pytorch3d.ops.knn import _knn_points_idx_naive +from test_knn import TestKNN def bm_knn() -> None: - """ Entry point for the benchmark """ - benchmark_knn_cpu() - benchmark_knn_cuda_vs_naive() - benchmark_knn_cuda_versions() - benchmark_knn_cuda_versions_ragged() + backends = ["cpu", "cuda:0"] -def benchmark_knn_cuda_versions() -> None: - # Compare our different KNN implementations, - # and also compare against our existing 1-NN - Ns = [1, 2] - Ps = [4096, 16384] + kwargs_list = [] + Ns = [32] + P1s = [256] + P2s = [128, 512] Ds = [3] - Ks = [1, 4, 16, 64] - versions = [0, 1, 2, 3] - knn_kwargs, nn_kwargs = [], [] - for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions): - if version == 2 and K > 32: - continue - if version == 3 and K > 4: - continue - knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version}) - for N, P, D in product(Ns, Ps, Ds): - 
nn_kwargs.append({"N": N, "D": D, "P": P}) - benchmark(knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1) - benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1) + Ks = [24] + test_cases = product(Ns, P1s, P2s, Ds, Ks, backends) + for case in test_cases: + N, P1, P2, D, K, b = case + kwargs_list.append({"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "device": b}) + benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1) -def benchmark_knn_cuda_versions_ragged() -> None: - # Compare our different KNN implementations, - # and also compare against our existing 1-NN - Ns = [8] - Ps = [4096, 16384] - Ds = [3] - Ks = [1, 4, 16, 64] - versions = [0, 1, 2, 3] - knn_kwargs = [] - for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions): - if version == 2 and K > 32: - continue - if version == 3 and K > 4: - continue - knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version}) - benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1) - benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1) - - -def benchmark_knn_cuda_vs_naive() -> None: - # Compare against naive pytorch version of KNN - Ns = [1, 2, 4] - Ps = [1024, 4096, 16384, 65536] - Ds = [3] - Ks = [1, 2, 4, 8, 16] - knn_kwargs, naive_kwargs = [], [] - for N, P, D, K in product(Ns, Ps, Ds, Ks): - knn_kwargs.append({"N": N, "D": D, "P": P, "K": K}) - if P <= 4096: - naive_kwargs.append({"N": N, "D": D, "P": P, "K": K}) - benchmark( - knn_python_cuda_with_init, "KNN_CUDA_PYTHON", naive_kwargs, warmup_iters=1 - ) - benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1) - - -def benchmark_knn_cpu() -> None: - Ns = [1, 2] - Ps = [256, 512] - Ds = [3] - Ks = [1, 2, 4] - knn_kwargs, nn_kwargs = [], [] - for N, P, D, K in product(Ns, Ps, Ds, Ks): - knn_kwargs.append({"N": N, "D": D, "P": P, "K": K}) - for N, P, D in product(Ns, Ps, Ds): - nn_kwargs.append({"N": N, "D": D, "P": P}) - 
benchmark(knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1) - benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1) - benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1) - - -def knn_cuda_with_init(N, D, P, K, v=-1): - device = torch.device("cuda:0") - x = torch.randn(N, P, D, device=device) - y = torch.randn(N, P, D, device=device) - lengths = torch.full((N,), P, dtype=torch.int64, device=device) - - torch.cuda.synchronize() - - def knn(): - _C.knn_points_idx(x, y, lengths, lengths, K, v) - torch.cuda.synchronize() - - return knn - - -def knn_cuda_ragged(N, D, P, K, v=-1): - device = torch.device("cuda:0") - x = torch.randn(N, P, D, device=device) - y = torch.randn(N, P, D, device=device) - lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64) - lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64) - torch.cuda.synchronize() - - def knn(): - _C.knn_points_idx(x, y, lengths1, lengths2, K, v) - torch.cuda.synchronize() - - return knn - - -def knn_cpu_with_init(N, D, P, K): - device = torch.device("cpu") - x = torch.randn(N, P, D, device=device) - y = torch.randn(N, P, D, device=device) - lengths = torch.full((N,), P, dtype=torch.int64, device=device) - - def knn(): - _C.knn_points_idx(x, y, lengths, lengths, K, -1) - - return knn - - -def knn_python_cuda_with_init(N, D, P, K): - device = torch.device("cuda") - x = torch.randn(N, P, D, device=device) - y = torch.randn(N, P, D, device=device) - lengths = torch.full((N,), P, dtype=torch.int64, device=device) - - torch.cuda.synchronize() - - def knn(): - _knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths) - torch.cuda.synchronize() - - return knn - - -def knn_python_cpu_with_init(N, D, P, K): - device = torch.device("cpu") - x = torch.randn(N, P, D, device=device) - y = torch.randn(N, P, D, device=device) - lengths = torch.full((N,), P, dtype=torch.int64, device=device) - - def knn(): - 
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths) - - return knn - - -def nn_cuda_with_init(N, D, P): - device = torch.device("cuda") - x = torch.randn(N, P, D, device=device) - y = torch.randn(N, P, D, device=device) - torch.cuda.synchronize() - - def knn(): - _C.nn_points_idx(x, y) - torch.cuda.synchronize() - - return knn - - -def nn_cpu_with_init(N, D, P): - device = torch.device("cpu") - x = torch.randn(N, P, D, device=device) - y = torch.randn(N, P, D, device=device) - - def knn(): - _C.nn_points_idx(x, y) - - return knn + benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1) diff --git a/tests/test_knn.py b/tests/test_knn.py index 5fe3698a..d39df6f0 100644 --- a/tests/test_knn.py +++ b/tests/test_knn.py @@ -4,116 +4,187 @@ import unittest from itertools import product import torch -from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx +from common_testing import TestCaseMixin +from pytorch3d.ops.knn import _KNN, knn_gather, knn_points -class TestKNN(unittest.TestCase): +class TestKNN(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() torch.manual_seed(1) - def _check_knn_result(self, out1, out2, sorted): - # When sorted=True, points should be sorted by distance and should - # match between implementations. When sorted=False we we only want to - # check that we got the same set of indices, so we sort the indices by - # index value. - idx1, dist1 = out1 - idx2, dist2 = out2 - if not sorted: - idx1 = idx1.sort(dim=2).values - idx2 = idx2.sort(dim=2).values - dist1 = dist1.sort(dim=2).values - dist2 = dist2.sort(dim=2).values - if not torch.all(idx1 == idx2): - print(idx1) - print(idx2) - self.assertTrue(torch.all(idx1 == idx2)) - self.assertTrue(torch.allclose(dist1, dist2)) + @staticmethod + def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor: + """ + Naive PyTorch implementation of K-Nearest Neighbors. 
+ Returns always sorted results + """ + N, P1, D = p1.shape + _N, P2, _D = p2.shape - def test_knn_vs_python_cpu_square(self): - """ Test CPU output vs PyTorch implementation """ - device = torch.device("cpu") - Ns = [1, 4] - Ds = [2, 3] - P1s = [1, 10, 101] - P2s = [10, 101] - Ks = [1, 3, 10] - sorts = [True, False] - factors = [Ns, Ds, P1s, P2s, Ks, sorts] - for N, D, P1, P2, K, sort in product(*factors): - lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device) - lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device) - x = torch.randn(N, P1, D, device=device) - y = torch.randn(N, P2, D, device=device) - out1 = _knn_points_idx_naive( - x, y, lengths1=lengths1, lengths2=lengths2, K=K - ) - out2 = knn_points_idx( - x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort - ) - self._check_knn_result(out1, out2, sort) + assert N == _N and D == _D - def test_knn_vs_python_cuda_square(self): - """ Test CUDA output vs PyTorch implementation """ - device = torch.device("cuda") + if lengths1 is None: + lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) + + dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device) + idx = torch.zeros((N, P1, K), dtype=torch.int64, device=p1.device) + + for n in range(N): + num1 = lengths1[n].item() + num2 = lengths2[n].item() + pp1 = p1[n, :num1].view(num1, 1, D) + pp2 = p2[n, :num2].view(1, num2, D) + diff = pp1 - pp2 + diff = (diff * diff).sum(2) + num2 = min(num2, K) + for i in range(num1): + dd = diff[i] + srt_dd, srt_idx = dd.sort() + + dists[n, i, :num2] = srt_dd[:num2] + idx[n, i, :num2] = srt_idx[:num2] + + return _KNN(dists=dists, idx=idx, knn=None) + + def _knn_vs_python_square_helper(self, device): Ns = [1, 4] - Ds = [2, 3, 8] - P1s = [1, 8, 64, 128, 1001] - P2s = [32, 128, 513] + Ds = [3, 5, 8] + P1s = [8, 24] + P2s = [8, 16, 32] Ks = [1, 3, 10] - sorts = [True, False] versions = 
[0, 1, 2, 3] - factors = [Ns, Ds, P1s, P2s, Ks, sorts] - for N, D, P1, P2, K, sort in product(*factors): - x = torch.randn(N, P1, D, device=device) - y = torch.randn(N, P2, D, device=device) - out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K) + factors = [Ns, Ds, P1s, P2s, Ks] + for N, D, P1, P2, K in product(*factors): for version in versions: if version == 3 and K > 4: continue - out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version) - self._check_knn_result(out1, out2, sort) + x = torch.randn(N, P1, D, device=device, requires_grad=True) + x_cuda = x.clone().detach() + x_cuda.requires_grad_(True) + y = torch.randn(N, P2, D, device=device, requires_grad=True) + y_cuda = y.clone().detach() + y_cuda.requires_grad_(True) - def test_knn_vs_python_cpu_ragged(self): + # forward + out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K) + out2 = knn_points(x_cuda, y_cuda, K=K, version=version) + self.assertClose(out1[0], out2[0]) + self.assertTrue(torch.all(out1[1] == out2[1])) + + # backward + grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device) + loss1 = (out1.dists * grad_dist).sum() + loss1.backward() + loss2 = (out2.dists * grad_dist).sum() + loss2.backward() + + self.assertClose(x_cuda.grad, x.grad, atol=5e-6) + self.assertClose(y_cuda.grad, y.grad, atol=5e-6) + + def test_knn_vs_python_square_cpu(self): device = torch.device("cpu") - lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64) - lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64) - N = 4 - D = 3 - Ks = [1, 9, 10, 11, 101] - sorts = [False, True] - factors = [Ks, sorts] - for K, sort in product(*factors): - x = torch.randn(N, lengths1.max(), D, device=device) - y = torch.randn(N, lengths2.max(), D, device=device) - out1 = _knn_points_idx_naive( - x, y, lengths1=lengths1, lengths2=lengths2, K=K - ) - out2 = knn_points_idx( - x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort - ) - 
self._check_knn_result(out1, out2, sort) + self._knn_vs_python_square_helper(device) - def test_knn_vs_python_cuda_ragged(self): - device = torch.device("cuda") - lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64) - lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64) - N = 4 - D = 3 - Ks = [1, 9, 10, 11, 101] - sorts = [True, False] - versions = [0, 1, 2, 3] - factors = [Ks, sorts] - for K, sort in product(*factors): - x = torch.randn(N, lengths1.max(), D, device=device) - y = torch.randn(N, lengths2.max(), D, device=device) - out1 = _knn_points_idx_naive( + def test_knn_vs_python_square_cuda(self): + device = torch.device("cuda:0") + self._knn_vs_python_square_helper(device) + + def _knn_vs_python_ragged_helper(self, device): + Ns = [1, 4] + Ds = [3, 5, 8] + P1s = [8, 24] + P2s = [8, 16, 32] + Ks = [1, 3, 10] + factors = [Ns, Ds, P1s, P2s, Ks] + for N, D, P1, P2, K in product(*factors): + x = torch.rand((N, P1, D), device=device, requires_grad=True) + y = torch.rand((N, P2, D), device=device, requires_grad=True) + lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device) + lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device) + + x_csrc = x.clone().detach() + x_csrc.requires_grad_(True) + y_csrc = y.clone().detach() + y_csrc.requires_grad_(True) + + # forward + out1 = self._knn_points_naive( x, y, lengths1=lengths1, lengths2=lengths2, K=K ) - for version in versions: - if version == 3 and K > 4: - continue - out2 = knn_points_idx( - x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort - ) - self._check_knn_result(out1, out2, sort) + out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K) + self.assertClose(out1[0], out2[0]) + self.assertTrue(torch.all(out1[1] == out2[1])) + + # backward + grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device) + loss1 = (out1.dists * grad_dist).sum() + loss1.backward() + loss2 = (out2.dists * grad_dist).sum() + 
loss2.backward() + + self.assertClose(x_csrc.grad, x.grad, atol=5e-6) + self.assertClose(y_csrc.grad, y.grad, atol=5e-6) + + def test_knn_vs_python_ragged_cpu(self): + device = torch.device("cpu") + self._knn_vs_python_ragged_helper(device) + + def test_knn_vs_python_ragged_cuda(self): + device = torch.device("cuda:0") + self._knn_vs_python_ragged_helper(device) + + def test_knn_gather(self): + device = torch.device("cuda:0") + N, P1, P2, K, D = 4, 16, 12, 8, 3 + x = torch.rand((N, P1, D), device=device) + y = torch.rand((N, P2, D), device=device) + lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device) + lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device) + + out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K) + y_nn = knn_gather(y, out.idx, lengths2) + + for n in range(N): + for p1 in range(P1): + for k in range(K): + if k < lengths2[n]: + self.assertClose(y_nn[n, p1, k], y[n, out.idx[n, p1, k]]) + else: + self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0)) + + @staticmethod + def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str): + device = torch.device(device) + pts1 = torch.randn(N, P1, D, device=device, requires_grad=True) + pts2 = torch.randn(N, P2, D, device=device, requires_grad=True) + grad_dists = torch.randn(N, P1, K, device=device) + torch.cuda.synchronize() + + def output(): + out = knn_points(pts1, pts2, K=K) + loss = (out.dists * grad_dists).sum() + loss.backward() + torch.cuda.synchronize() + + return output + + @staticmethod + def knn_ragged(N: int, P1: int, P2: int, D: int, K: int, device: str): + device = torch.device(device) + pts1 = torch.rand((N, P1, D), device=device, requires_grad=True) + pts2 = torch.rand((N, P2, D), device=device, requires_grad=True) + lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device) + lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device) + grad_dists = torch.randn(N, P1, K, device=device) + torch.cuda.synchronize() + + def output(): + 
out = knn_points(pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K) + loss = (out.dists * grad_dists).sum() + loss.backward() + torch.cuda.synchronize() + + return output