mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
knn autograd
Summary: Adds knn backward to return `grad_pts1` and `grad_pts2`. Adds `knn_gather` to return the nearest neighbors in pts2. The BM tests include backward pass and are ran on an M40. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- KNN_SQUARE_32_256_128_3_24_cpu 39558 43485 13 KNN_SQUARE_32_256_128_3_24_cuda:0 1080 1404 463 KNN_SQUARE_32_256_512_3_24_cpu 81950 85781 7 KNN_SQUARE_32_256_512_3_24_cuda:0 1519 1641 330 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- KNN_RAGGED_32_256_128_3_24_cpu 13798 14650 37 KNN_RAGGED_32_256_128_3_24_cuda:0 1576 1713 318 KNN_RAGGED_32_256_512_3_24_cpu 31255 32210 16 KNN_RAGGED_32_256_512_3_24_cuda:0 2024 2162 248 -------------------------------------------------------------------------------- ``` Reviewed By: jcjohnson Differential Revision: D20945556 fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
This commit is contained in:
parent
487d4d6607
commit
b2b0c5a442
@ -20,6 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("packed_to_padded", &PackedToPadded);
|
m.def("packed_to_padded", &PackedToPadded);
|
||||||
m.def("padded_to_packed", &PaddedToPacked);
|
m.def("padded_to_packed", &PaddedToPacked);
|
||||||
m.def("knn_points_idx", &KNearestNeighborIdx);
|
m.def("knn_points_idx", &KNearestNeighborIdx);
|
||||||
|
m.def("knn_points_backward", &KNearestNeighborBackward);
|
||||||
m.def("nn_points_idx", &NearestNeighborIdx);
|
m.def("nn_points_idx", &NearestNeighborIdx);
|
||||||
m.def("gather_scatter", &gather_scatter);
|
m.def("gather_scatter", &gather_scatter);
|
||||||
m.def("rasterize_points", &RasterizePoints);
|
m.def("rasterize_points", &RasterizePoints);
|
||||||
|
@ -412,3 +412,93 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
|
|
||||||
return std::make_tuple(idxs, dists);
|
return std::make_tuple(idxs, dists);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------- //
|
||||||
|
// Backward Operators //
|
||||||
|
// ------------------------------------------------------------- //
|
||||||
|
|
||||||
|
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
|
||||||
|
// Currently, support is for floats only.
|
||||||
|
__global__ void KNearestNeighborBackwardKernel(
|
||||||
|
const float* __restrict__ p1, // (N, P1, D)
|
||||||
|
const float* __restrict__ p2, // (N, P2, D)
|
||||||
|
const int64_t* __restrict__ lengths1, // (N,)
|
||||||
|
const int64_t* __restrict__ lengths2, // (N,)
|
||||||
|
const int64_t* __restrict__ idxs, // (N, P1, K)
|
||||||
|
const float* __restrict__ grad_dists, // (N, P1, K)
|
||||||
|
float* __restrict__ grad_p1, // (N, P1, D)
|
||||||
|
float* __restrict__ grad_p2, // (N, P2, D)
|
||||||
|
const size_t N,
|
||||||
|
const size_t P1,
|
||||||
|
const size_t P2,
|
||||||
|
const size_t K,
|
||||||
|
const size_t D) {
|
||||||
|
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const size_t stride = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (size_t i = tid; i < N * P1 * K * D; i += stride) {
|
||||||
|
const size_t n = i / (P1 * K * D); // batch index
|
||||||
|
size_t rem = i % (P1 * K * D);
|
||||||
|
const size_t p1_idx = rem / (K * D); // index of point in p1
|
||||||
|
rem = rem % (K * D);
|
||||||
|
const size_t k = rem / D; // k-th nearest neighbor
|
||||||
|
const size_t d = rem % D; // d-th dimension in the feature vector
|
||||||
|
|
||||||
|
const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
|
||||||
|
const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
|
||||||
|
if ((p1_idx < num1) && (k < num2)) {
|
||||||
|
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
|
||||||
|
// index of point in p2 corresponding to the k-th nearest neighbor
|
||||||
|
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
|
||||||
|
const float diff = 2.0 * grad_dist *
|
||||||
|
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
|
||||||
|
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
|
||||||
|
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
const at::Tensor& idxs,
|
||||||
|
const at::Tensor& grad_dists) {
|
||||||
|
const auto N = p1.size(0);
|
||||||
|
const auto P1 = p1.size(1);
|
||||||
|
const auto P2 = p2.size(1);
|
||||||
|
const auto D = p2.size(2);
|
||||||
|
const auto K = idxs.size(2);
|
||||||
|
|
||||||
|
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
|
||||||
|
AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
|
||||||
|
AT_ASSERTM(
|
||||||
|
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
|
||||||
|
AT_ASSERTM(grad_dists.size(0) == N);
|
||||||
|
AT_ASSERTM(grad_dists.size(1) == P1);
|
||||||
|
AT_ASSERTM(grad_dists.size(2) == K);
|
||||||
|
|
||||||
|
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
|
||||||
|
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
|
||||||
|
|
||||||
|
const int blocks = 64;
|
||||||
|
const int threads = 512;
|
||||||
|
|
||||||
|
KNearestNeighborBackwardKernel<<<blocks, threads>>>(
|
||||||
|
p1.data_ptr<float>(),
|
||||||
|
p2.data_ptr<float>(),
|
||||||
|
lengths1.data_ptr<int64_t>(),
|
||||||
|
lengths2.data_ptr<int64_t>(),
|
||||||
|
idxs.data_ptr<int64_t>(),
|
||||||
|
grad_dists.data_ptr<float>(),
|
||||||
|
grad_p1.data_ptr<float>(),
|
||||||
|
grad_p2.data_ptr<float>(),
|
||||||
|
N,
|
||||||
|
P1,
|
||||||
|
P2,
|
||||||
|
K,
|
||||||
|
D);
|
||||||
|
|
||||||
|
return std::make_tuple(grad_p1, grad_p2);
|
||||||
|
}
|
||||||
|
@ -16,8 +16,6 @@
|
|||||||
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||||
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||||
// K: int giving the number of nearest points to return.
|
// K: int giving the number of nearest points to return.
|
||||||
// sorted: bool telling whether to sort the K returned points by their
|
|
||||||
// distance.
|
|
||||||
// version: Integer telling which implementation to use.
|
// version: Integer telling which implementation to use.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
@ -67,3 +65,66 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
|||||||
}
|
}
|
||||||
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
|
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compute gradients with respect to p1 and p2
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
||||||
|
// containing P1 points of dimension D.
|
||||||
|
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||||
|
// containing P2 points of dimension D.
|
||||||
|
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||||
|
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||||
|
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||||
|
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
||||||
|
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||||
|
// It is padded with zeros so that it can be used easily in a later
|
||||||
|
// gather() operation. This is computed from the forward pass.
|
||||||
|
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
|
||||||
|
// gradients.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
|
||||||
|
// wrt p1.
|
||||||
|
// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
|
||||||
|
// wrt p2.
|
||||||
|
|
||||||
|
// CPU implementation.
|
||||||
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
const at::Tensor& idxs,
|
||||||
|
const at::Tensor& grad_dists);
|
||||||
|
|
||||||
|
// CUDA implementation
|
||||||
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
const at::Tensor& idxs,
|
||||||
|
const at::Tensor& grad_dists);
|
||||||
|
|
||||||
|
// Implementation which is exposed.
|
||||||
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
const at::Tensor& idxs,
|
||||||
|
const at::Tensor& grad_dists) {
|
||||||
|
if (p1.is_cuda() || p2.is_cuda()) {
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
CHECK_CONTIGUOUS_CUDA(p1);
|
||||||
|
CHECK_CONTIGUOUS_CUDA(p2);
|
||||||
|
return KNearestNeighborBackwardCuda(
|
||||||
|
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
||||||
|
#else
|
||||||
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
return KNearestNeighborBackwardCpu(
|
||||||
|
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
||||||
|
}
|
||||||
|
@ -57,3 +57,51 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
|||||||
}
|
}
|
||||||
return std::make_tuple(idxs, dists);
|
return std::make_tuple(idxs, dists);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------- //
|
||||||
|
// Backward Operators //
|
||||||
|
// ------------------------------------------------------------- //
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
const at::Tensor& idxs,
|
||||||
|
const at::Tensor& grad_dists) {
|
||||||
|
const int N = p1.size(0);
|
||||||
|
const int P1 = p1.size(1);
|
||||||
|
const int D = p1.size(2);
|
||||||
|
const int P2 = p2.size(1);
|
||||||
|
const int K = idxs.size(2);
|
||||||
|
|
||||||
|
torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
|
||||||
|
torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());
|
||||||
|
|
||||||
|
auto p1_a = p1.accessor<float, 3>();
|
||||||
|
auto p2_a = p2.accessor<float, 3>();
|
||||||
|
auto lengths1_a = lengths1.accessor<int64_t, 1>();
|
||||||
|
auto lengths2_a = lengths2.accessor<int64_t, 1>();
|
||||||
|
auto idxs_a = idxs.accessor<int64_t, 3>();
|
||||||
|
auto grad_dists_a = grad_dists.accessor<float, 3>();
|
||||||
|
auto grad_p1_a = grad_p1.accessor<float, 3>();
|
||||||
|
auto grad_p2_a = grad_p2.accessor<float, 3>();
|
||||||
|
|
||||||
|
for (int n = 0; n < N; ++n) {
|
||||||
|
const int64_t length1 = lengths1_a[n];
|
||||||
|
int64_t length2 = lengths2_a[n];
|
||||||
|
length2 = (length2 < K) ? length2 : K;
|
||||||
|
for (int64_t i1 = 0; i1 < length1; ++i1) {
|
||||||
|
for (int64_t k = 0; k < length2; ++k) {
|
||||||
|
const int64_t i2 = idxs_a[n][i1][k];
|
||||||
|
for (int64_t d = 0; d < D; ++d) {
|
||||||
|
const float diff =
|
||||||
|
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
|
||||||
|
grad_p1_a[n][i1][d] += diff;
|
||||||
|
grad_p2_a[n][i2][d] += -1.0f * diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_tuple(grad_p1, grad_p2);
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from .cubify import cubify
|
from .cubify import cubify
|
||||||
from .graph_conv import GraphConv
|
from .graph_conv import GraphConv
|
||||||
|
from .knn import knn_gather, knn_points
|
||||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||||
from .nearest_neighbor_points import nn_points_idx
|
from .nearest_neighbor_points import nn_points_idx
|
||||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||||
|
@ -1,152 +1,215 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d import _C
|
from pytorch3d import _C
|
||||||
|
from torch.autograd import Function
|
||||||
|
from torch.autograd.function import once_differentiable
|
||||||
|
|
||||||
|
|
||||||
def knn_points_idx(
|
_KNN = namedtuple("KNN", "dists idx knn")
|
||||||
p1,
|
|
||||||
p2,
|
|
||||||
K: int,
|
class _knn_points(Function):
|
||||||
lengths1=None,
|
"""
|
||||||
lengths2=None,
|
Torch autograd Function wrapper for KNN C++/CUDA implementations.
|
||||||
sorted: bool = False,
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, p1, p2, lengths1, lengths2, K, version):
|
||||||
|
"""
|
||||||
|
K-Nearest neighbors on point clouds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
|
||||||
|
containing up to P1 points of dimension D.
|
||||||
|
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
|
||||||
|
containing up to P2 points of dimension D.
|
||||||
|
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||||
|
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||||
|
length P1.
|
||||||
|
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||||
|
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||||
|
length P2.
|
||||||
|
K: Integer giving the number of nearest neighbors to return.
|
||||||
|
version: Which KNN implementation to use in the backend. If version=-1,
|
||||||
|
the correct implementation is selected based on the shapes of the inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||||
|
the nearest neighbors. This is padded with zeros both where a cloud in p2
|
||||||
|
has fewer than K points and where a cloud in p1 has fewer than P1 points.
|
||||||
|
|
||||||
|
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||||
|
K nearest neighbors from points in p1 to points in p2.
|
||||||
|
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
|
||||||
|
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
|
||||||
|
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
|
||||||
|
"""
|
||||||
|
|
||||||
|
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
|
||||||
|
|
||||||
|
# sort KNN in ascending order if K > 1
|
||||||
|
if K > 1:
|
||||||
|
if lengths2.min() < K:
|
||||||
|
P1 = p1.shape[1]
|
||||||
|
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
|
||||||
|
# mask has shape [N, K], true where dists irrelevant
|
||||||
|
mask = mask[:, None].expand(-1, P1, -1)
|
||||||
|
# mask has shape [N, P1, K], true where dists irrelevant
|
||||||
|
dists[mask] = float("inf")
|
||||||
|
dists, sort_idx = dists.sort(dim=2)
|
||||||
|
dists[mask] = 0
|
||||||
|
else:
|
||||||
|
dists, sort_idx = dists.sort(dim=2)
|
||||||
|
idx = idx.gather(2, sort_idx)
|
||||||
|
|
||||||
|
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
|
||||||
|
return dists, idx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@once_differentiable
|
||||||
|
def backward(ctx, grad_dists, grad_idx):
|
||||||
|
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
||||||
|
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||||
|
if not (grad_dists.dtype == torch.float32):
|
||||||
|
grad_dists = grad_dists.float()
|
||||||
|
if not (p1.dtype == torch.float32):
|
||||||
|
p1 = p1.float()
|
||||||
|
if not (p2.dtype == torch.float32):
|
||||||
|
p2 = p2.float()
|
||||||
|
grad_p1, grad_p2 = _C.knn_points_backward(
|
||||||
|
p1, p2, lengths1, lengths2, idx, grad_dists
|
||||||
|
)
|
||||||
|
return grad_p1, grad_p2, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def knn_points(
|
||||||
|
p1: torch.Tensor,
|
||||||
|
p2: torch.Tensor,
|
||||||
|
lengths1: Union[torch.Tensor, None] = None,
|
||||||
|
lengths2: Union[torch.Tensor, None] = None,
|
||||||
|
K: int = 1,
|
||||||
version: int = -1,
|
version: int = -1,
|
||||||
|
return_nn: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
K-Nearest neighbors on point clouds.
|
K-Nearest neighbors on point clouds.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
|
||||||
containing up to P1 points of dimension D.
|
containing up to P1 points of dimension D.
|
||||||
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
|
||||||
containing up to P2 points of dimension D.
|
containing up to P2 points of dimension D.
|
||||||
K: Integer giving the number of nearest neighbors to return.
|
|
||||||
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||||
length of each pointcloud in p1. Or None to indicate that every cloud has
|
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||||
length P1.
|
length P1.
|
||||||
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||||
length of each pointcloud in p2. Or None to indicate that every cloud has
|
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||||
length P2.
|
length P2.
|
||||||
sorted: Whether to sort the resulting points.
|
K: Integer giving the number of nearest neighbors to return.
|
||||||
version: Which KNN implementation to use in the backend. If version=-1,
|
version: Which KNN implementation to use in the backend. If version=-1,
|
||||||
the correct implementation is selected based on the shapes of the inputs.
|
the correct implementation is selected based on the shapes of the inputs.
|
||||||
|
return_nn: If set to True returns the K nearest neighors in p2 for each point in p1.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the
|
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||||
K nearest neighbors from points in p1 to points in p2.
|
K nearest neighbors from points in p1 to points in p2.
|
||||||
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest
|
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
|
||||||
neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth
|
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
|
||||||
nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud
|
|
||||||
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
|
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
|
||||||
points.
|
points.
|
||||||
If you want an (N, P1, K, D) tensor of the actual points, you can get it
|
|
||||||
using
|
|
||||||
p2[:, :, None].expand(-1, -1, K, -1).gather(1,
|
|
||||||
x_idx[:, :, :, None].expand(-1, -1, -1, D)
|
|
||||||
)
|
|
||||||
If K=1 and you want an (N, P1, D) tensor of the actual points, use
|
|
||||||
p2.gather(1, x_idx.expand(-1, -1, D))
|
|
||||||
|
|
||||||
p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to
|
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||||
the nearest neighbors. This is padded with zeros both where a cloud in p2
|
the nearest neighbors. This is padded with zeros both where a cloud in p2
|
||||||
has fewer than K points and where a cloud in p1 has fewer than P1 points.
|
has fewer than K points and where a cloud in p1 has fewer than P1 points.
|
||||||
Warning: this is calculated outside of the autograd framework.
|
|
||||||
|
p2_nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
|
||||||
|
each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor
|
||||||
|
for `p1[n, i]`. Returned if `return_nn` is True.
|
||||||
|
The nearest neighbors are collected using `knn_gather`
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
p2_nn = knn_gather(p2, p1_idx, lengths2)
|
||||||
|
|
||||||
|
which is a helper function that allows indexing any tensor of shape (N, P2, U) with
|
||||||
|
the indices `p1_idx` returned by `knn_points`. The outout is a tensor
|
||||||
|
of shape (N, P1, K, U).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if p1.shape[0] != p2.shape[0]:
|
||||||
|
raise ValueError("pts1 and pts2 must have the same batch dimension.")
|
||||||
|
if p1.shape[2] != p2.shape[2]:
|
||||||
|
raise ValueError("pts1 and pts2 must have the same point dimension.")
|
||||||
|
|
||||||
P1 = p1.shape[1]
|
P1 = p1.shape[1]
|
||||||
P2 = p2.shape[1]
|
P2 = p2.shape[1]
|
||||||
|
|
||||||
if lengths1 is None:
|
if lengths1 is None:
|
||||||
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
|
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
|
||||||
if lengths2 is None:
|
if lengths2 is None:
|
||||||
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
|
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
|
||||||
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
|
|
||||||
if sorted:
|
p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version)
|
||||||
if lengths2.min() < K:
|
|
||||||
device = dists.device
|
p2_nn = None
|
||||||
mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None]
|
if return_nn:
|
||||||
# mask1 has shape [N, K], true where dists irrelevant
|
p2_nn = knn_gather(p2, p1_idx, lengths2)
|
||||||
mask2 = mask1[:, None].expand(-1, P1, -1)
|
|
||||||
# mask2 has shape [N, P1, K], true where dists irrelevant
|
return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None)
|
||||||
dists[mask2] = float("inf")
|
|
||||||
dists, sort_idx = dists.sort(dim=2)
|
|
||||||
dists[mask2] = 0
|
|
||||||
else:
|
|
||||||
dists, sort_idx = dists.sort(dim=2)
|
|
||||||
idx = idx.gather(2, sort_idx)
|
|
||||||
return idx, dists
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
def knn_gather(
|
||||||
def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2) -> torch.Tensor:
|
x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Naive PyTorch implementation of K-Nearest Neighbors.
|
A helper function for knn that allows indexing a tensor x with the indices `idx`
|
||||||
|
returned by `knn_points`.
|
||||||
|
|
||||||
This is much less efficient than _C.knn_points_idx, but we include this
|
For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)`
|
||||||
naive implementation for testing and benchmarking.
|
where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D),
|
||||||
|
then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`.
|
||||||
|
It can also be applied for any tensor x of shape (N, M, U) where U != D.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
|
x: Tensor of shape (N, M, U) containing U-dimensional features to
|
||||||
containing up to P1 points of dimension D.
|
be gathered.
|
||||||
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
|
idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`.
|
||||||
containing up to P2 points of dimension D.
|
lengths: LongTensor of shape (N,) of values in the range [0, M], giving the
|
||||||
K: Integer giving the number of nearest neighbors to return.
|
length of each example in the batch in x. Or None to indicate that every
|
||||||
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
example has length M.
|
||||||
length of each pointcloud in p1. Or None to indicate that every cloud has
|
|
||||||
length P1.
|
|
||||||
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
|
||||||
length of each pointcloud in p2. Or None to indicate that every cloud has
|
|
||||||
length P2.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x
|
||||||
K nearest neighbors from points in p1 to points in p2.
|
with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`.
|
||||||
Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor
|
If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0.
|
||||||
to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer
|
|
||||||
than K points and where a cloud in p1 has fewer than P1 points.
|
|
||||||
dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest
|
|
||||||
neighbors. This is padded with zeros both where a cloud in p2 has fewer than
|
|
||||||
K points and where a cloud in p1 has fewer than P1 points.
|
|
||||||
"""
|
"""
|
||||||
N, P1, D = p1.shape
|
N, M, U = x.shape
|
||||||
_N, P2, _D = p2.shape
|
_N, L, K = idx.shape
|
||||||
|
|
||||||
assert N == _N and D == _D
|
if N != _N:
|
||||||
|
raise ValueError("x and idx must have same batch dimension.")
|
||||||
|
|
||||||
if lengths1 is None:
|
if lengths is None:
|
||||||
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
|
lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device)
|
||||||
if lengths2 is None:
|
|
||||||
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
|
||||||
|
|
||||||
p1_copy = p1.clone()
|
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U)
|
||||||
p2_copy = p2.clone()
|
# idx_expanded has shape [N, L, K, U]
|
||||||
|
|
||||||
# We pad the values with infinities so that the smallest differences are
|
x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)
|
||||||
# among actual points.
|
# p2_nn has shape [N, L, K, U]
|
||||||
inf = float("inf")
|
|
||||||
p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None]
|
|
||||||
p1_copy[p1_mask] = inf
|
|
||||||
p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf
|
|
||||||
|
|
||||||
# view is safe here: we are merely adding extra dimensions of length 1
|
needs_mask = lengths.min() < K
|
||||||
diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D)
|
if needs_mask:
|
||||||
dists2 = (diffs * diffs).sum(dim=3)
|
# mask has shape [N, K], true where idx is irrelevant because
|
||||||
|
# there is less number of points in p2 than K
|
||||||
|
mask = lengths[:, None] <= torch.arange(K, device=x.device)[None]
|
||||||
|
|
||||||
# We always sort, because this works well with padding.
|
# expand mask to shape [N, L, K, U]
|
||||||
out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True)
|
mask = mask[:, None].expand(-1, L, -1)
|
||||||
|
mask = mask[:, :, :, None].expand(-1, -1, -1, U)
|
||||||
|
x_out[mask] = 0.0
|
||||||
|
|
||||||
out_indices = out.indices
|
return x_out
|
||||||
out_values = out.values
|
|
||||||
|
|
||||||
if P2 < K:
|
|
||||||
# Need to add padding
|
|
||||||
pad_shape = (N, P1, K - P2)
|
|
||||||
out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2)
|
|
||||||
out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2)
|
|
||||||
|
|
||||||
K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None]
|
|
||||||
# Create a combined mask for where the points in p1 are padded
|
|
||||||
# or the corresponding p2 has fewer than K points.
|
|
||||||
p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :]
|
|
||||||
out_indices[p1_K_mask] = 0
|
|
||||||
out_values[p1_K_mask] = 0
|
|
||||||
return out_indices, out_values
|
|
||||||
|
181
tests/bm_knn.py
181
tests/bm_knn.py
@ -2,180 +2,25 @@
|
|||||||
|
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import torch
|
|
||||||
from fvcore.common.benchmark import benchmark
|
from fvcore.common.benchmark import benchmark
|
||||||
from pytorch3d import _C
|
from test_knn import TestKNN
|
||||||
from pytorch3d.ops.knn import _knn_points_idx_naive
|
|
||||||
|
|
||||||
|
|
||||||
def bm_knn() -> None:
|
def bm_knn() -> None:
|
||||||
""" Entry point for the benchmark """
|
|
||||||
benchmark_knn_cpu()
|
|
||||||
benchmark_knn_cuda_vs_naive()
|
|
||||||
benchmark_knn_cuda_versions()
|
|
||||||
benchmark_knn_cuda_versions_ragged()
|
|
||||||
|
|
||||||
|
backends = ["cpu", "cuda:0"]
|
||||||
|
|
||||||
def benchmark_knn_cuda_versions() -> None:
|
kwargs_list = []
|
||||||
# Compare our different KNN implementations,
|
Ns = [32]
|
||||||
# and also compare against our existing 1-NN
|
P1s = [256]
|
||||||
Ns = [1, 2]
|
P2s = [128, 512]
|
||||||
Ps = [4096, 16384]
|
|
||||||
Ds = [3]
|
Ds = [3]
|
||||||
Ks = [1, 4, 16, 64]
|
Ks = [24]
|
||||||
versions = [0, 1, 2, 3]
|
test_cases = product(Ns, P1s, P2s, Ds, Ks, backends)
|
||||||
knn_kwargs, nn_kwargs = [], []
|
for case in test_cases:
|
||||||
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
|
N, P1, P2, D, K, b = case
|
||||||
if version == 2 and K > 32:
|
kwargs_list.append({"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "device": b})
|
||||||
continue
|
|
||||||
if version == 3 and K > 4:
|
|
||||||
continue
|
|
||||||
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
|
|
||||||
for N, P, D in product(Ns, Ps, Ds):
|
|
||||||
nn_kwargs.append({"N": N, "D": D, "P": P})
|
|
||||||
benchmark(knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1)
|
|
||||||
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
|
|
||||||
|
|
||||||
|
benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1)
|
||||||
|
|
||||||
def benchmark_knn_cuda_versions_ragged() -> None:
|
benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1)
|
||||||
# Compare our different KNN implementations,
|
|
||||||
# and also compare against our existing 1-NN
|
|
||||||
Ns = [8]
|
|
||||||
Ps = [4096, 16384]
|
|
||||||
Ds = [3]
|
|
||||||
Ks = [1, 4, 16, 64]
|
|
||||||
versions = [0, 1, 2, 3]
|
|
||||||
knn_kwargs = []
|
|
||||||
for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
|
|
||||||
if version == 2 and K > 32:
|
|
||||||
continue
|
|
||||||
if version == 3 and K > 4:
|
|
||||||
continue
|
|
||||||
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
|
|
||||||
benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
|
|
||||||
benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_knn_cuda_vs_naive() -> None:
|
|
||||||
# Compare against naive pytorch version of KNN
|
|
||||||
Ns = [1, 2, 4]
|
|
||||||
Ps = [1024, 4096, 16384, 65536]
|
|
||||||
Ds = [3]
|
|
||||||
Ks = [1, 2, 4, 8, 16]
|
|
||||||
knn_kwargs, naive_kwargs = [], []
|
|
||||||
for N, P, D, K in product(Ns, Ps, Ds, Ks):
|
|
||||||
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
|
|
||||||
if P <= 4096:
|
|
||||||
naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
|
|
||||||
benchmark(
|
|
||||||
knn_python_cuda_with_init, "KNN_CUDA_PYTHON", naive_kwargs, warmup_iters=1
|
|
||||||
)
|
|
||||||
benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_knn_cpu() -> None:
|
|
||||||
Ns = [1, 2]
|
|
||||||
Ps = [256, 512]
|
|
||||||
Ds = [3]
|
|
||||||
Ks = [1, 2, 4]
|
|
||||||
knn_kwargs, nn_kwargs = [], []
|
|
||||||
for N, P, D, K in product(Ns, Ps, Ds, Ks):
|
|
||||||
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
|
|
||||||
for N, P, D in product(Ns, Ps, Ds):
|
|
||||||
nn_kwargs.append({"N": N, "D": D, "P": P})
|
|
||||||
benchmark(knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1)
|
|
||||||
benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
|
|
||||||
benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)
|
|
||||||
|
|
||||||
|
|
||||||
def knn_cuda_with_init(N, D, P, K, v=-1):
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
x = torch.randn(N, P, D, device=device)
|
|
||||||
y = torch.randn(N, P, D, device=device)
|
|
||||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
def knn():
|
|
||||||
_C.knn_points_idx(x, y, lengths, lengths, K, v)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return knn
|
|
||||||
|
|
||||||
|
|
||||||
def knn_cuda_ragged(N, D, P, K, v=-1):
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
x = torch.randn(N, P, D, device=device)
|
|
||||||
y = torch.randn(N, P, D, device=device)
|
|
||||||
lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
|
||||||
lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
def knn():
|
|
||||||
_C.knn_points_idx(x, y, lengths1, lengths2, K, v)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return knn
|
|
||||||
|
|
||||||
|
|
||||||
def knn_cpu_with_init(N, D, P, K):
|
|
||||||
device = torch.device("cpu")
|
|
||||||
x = torch.randn(N, P, D, device=device)
|
|
||||||
y = torch.randn(N, P, D, device=device)
|
|
||||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
|
||||||
|
|
||||||
def knn():
|
|
||||||
_C.knn_points_idx(x, y, lengths, lengths, K, -1)
|
|
||||||
|
|
||||||
return knn
|
|
||||||
|
|
||||||
|
|
||||||
def knn_python_cuda_with_init(N, D, P, K):
|
|
||||||
device = torch.device("cuda")
|
|
||||||
x = torch.randn(N, P, D, device=device)
|
|
||||||
y = torch.randn(N, P, D, device=device)
|
|
||||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
def knn():
|
|
||||||
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return knn
|
|
||||||
|
|
||||||
|
|
||||||
def knn_python_cpu_with_init(N, D, P, K):
|
|
||||||
device = torch.device("cpu")
|
|
||||||
x = torch.randn(N, P, D, device=device)
|
|
||||||
y = torch.randn(N, P, D, device=device)
|
|
||||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
|
||||||
|
|
||||||
def knn():
|
|
||||||
_knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
|
|
||||||
|
|
||||||
return knn
|
|
||||||
|
|
||||||
|
|
||||||
def nn_cuda_with_init(N, D, P):
|
|
||||||
device = torch.device("cuda")
|
|
||||||
x = torch.randn(N, P, D, device=device)
|
|
||||||
y = torch.randn(N, P, D, device=device)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
def knn():
|
|
||||||
_C.nn_points_idx(x, y)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return knn
|
|
||||||
|
|
||||||
|
|
||||||
def nn_cpu_with_init(N, D, P):
|
|
||||||
device = torch.device("cpu")
|
|
||||||
x = torch.randn(N, P, D, device=device)
|
|
||||||
y = torch.randn(N, P, D, device=device)
|
|
||||||
|
|
||||||
def knn():
|
|
||||||
_C.nn_points_idx(x, y)
|
|
||||||
|
|
||||||
return knn
|
|
||||||
|
@ -4,116 +4,187 @@ import unittest
|
|||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
|
||||||
|
|
||||||
|
|
||||||
class TestKNN(unittest.TestCase):
|
class TestKNN(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
|
|
||||||
def _check_knn_result(self, out1, out2, sorted):
|
@staticmethod
|
||||||
# When sorted=True, points should be sorted by distance and should
|
def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
|
||||||
# match between implementations. When sorted=False we we only want to
|
"""
|
||||||
# check that we got the same set of indices, so we sort the indices by
|
Naive PyTorch implementation of K-Nearest Neighbors.
|
||||||
# index value.
|
Returns always sorted results
|
||||||
idx1, dist1 = out1
|
"""
|
||||||
idx2, dist2 = out2
|
N, P1, D = p1.shape
|
||||||
if not sorted:
|
_N, P2, _D = p2.shape
|
||||||
idx1 = idx1.sort(dim=2).values
|
|
||||||
idx2 = idx2.sort(dim=2).values
|
|
||||||
dist1 = dist1.sort(dim=2).values
|
|
||||||
dist2 = dist2.sort(dim=2).values
|
|
||||||
if not torch.all(idx1 == idx2):
|
|
||||||
print(idx1)
|
|
||||||
print(idx2)
|
|
||||||
self.assertTrue(torch.all(idx1 == idx2))
|
|
||||||
self.assertTrue(torch.allclose(dist1, dist2))
|
|
||||||
|
|
||||||
def test_knn_vs_python_cpu_square(self):
|
assert N == _N and D == _D
|
||||||
""" Test CPU output vs PyTorch implementation """
|
|
||||||
device = torch.device("cpu")
|
|
||||||
Ns = [1, 4]
|
|
||||||
Ds = [2, 3]
|
|
||||||
P1s = [1, 10, 101]
|
|
||||||
P2s = [10, 101]
|
|
||||||
Ks = [1, 3, 10]
|
|
||||||
sorts = [True, False]
|
|
||||||
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
|
|
||||||
for N, D, P1, P2, K, sort in product(*factors):
|
|
||||||
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
|
|
||||||
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
|
|
||||||
x = torch.randn(N, P1, D, device=device)
|
|
||||||
y = torch.randn(N, P2, D, device=device)
|
|
||||||
out1 = _knn_points_idx_naive(
|
|
||||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
|
||||||
)
|
|
||||||
out2 = knn_points_idx(
|
|
||||||
x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
|
|
||||||
)
|
|
||||||
self._check_knn_result(out1, out2, sort)
|
|
||||||
|
|
||||||
def test_knn_vs_python_cuda_square(self):
|
if lengths1 is None:
|
||||||
""" Test CUDA output vs PyTorch implementation """
|
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
|
||||||
device = torch.device("cuda")
|
if lengths2 is None:
|
||||||
|
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
||||||
|
|
||||||
|
dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
|
||||||
|
idx = torch.zeros((N, P1, K), dtype=torch.int64, device=p1.device)
|
||||||
|
|
||||||
|
for n in range(N):
|
||||||
|
num1 = lengths1[n].item()
|
||||||
|
num2 = lengths2[n].item()
|
||||||
|
pp1 = p1[n, :num1].view(num1, 1, D)
|
||||||
|
pp2 = p2[n, :num2].view(1, num2, D)
|
||||||
|
diff = pp1 - pp2
|
||||||
|
diff = (diff * diff).sum(2)
|
||||||
|
num2 = min(num2, K)
|
||||||
|
for i in range(num1):
|
||||||
|
dd = diff[i]
|
||||||
|
srt_dd, srt_idx = dd.sort()
|
||||||
|
|
||||||
|
dists[n, i, :num2] = srt_dd[:num2]
|
||||||
|
idx[n, i, :num2] = srt_idx[:num2]
|
||||||
|
|
||||||
|
return _KNN(dists=dists, idx=idx, knn=None)
|
||||||
|
|
||||||
|
def _knn_vs_python_square_helper(self, device):
|
||||||
Ns = [1, 4]
|
Ns = [1, 4]
|
||||||
Ds = [2, 3, 8]
|
Ds = [3, 5, 8]
|
||||||
P1s = [1, 8, 64, 128, 1001]
|
P1s = [8, 24]
|
||||||
P2s = [32, 128, 513]
|
P2s = [8, 16, 32]
|
||||||
Ks = [1, 3, 10]
|
Ks = [1, 3, 10]
|
||||||
sorts = [True, False]
|
|
||||||
versions = [0, 1, 2, 3]
|
versions = [0, 1, 2, 3]
|
||||||
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
|
factors = [Ns, Ds, P1s, P2s, Ks]
|
||||||
for N, D, P1, P2, K, sort in product(*factors):
|
for N, D, P1, P2, K in product(*factors):
|
||||||
x = torch.randn(N, P1, D, device=device)
|
|
||||||
y = torch.randn(N, P2, D, device=device)
|
|
||||||
out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
|
|
||||||
for version in versions:
|
for version in versions:
|
||||||
if version == 3 and K > 4:
|
if version == 3 and K > 4:
|
||||||
continue
|
continue
|
||||||
out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
|
x = torch.randn(N, P1, D, device=device, requires_grad=True)
|
||||||
self._check_knn_result(out1, out2, sort)
|
x_cuda = x.clone().detach()
|
||||||
|
x_cuda.requires_grad_(True)
|
||||||
|
y = torch.randn(N, P2, D, device=device, requires_grad=True)
|
||||||
|
y_cuda = y.clone().detach()
|
||||||
|
y_cuda.requires_grad_(True)
|
||||||
|
|
||||||
def test_knn_vs_python_cpu_ragged(self):
|
# forward
|
||||||
|
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
|
||||||
|
out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
|
||||||
|
self.assertClose(out1[0], out2[0])
|
||||||
|
self.assertTrue(torch.all(out1[1] == out2[1]))
|
||||||
|
|
||||||
|
# backward
|
||||||
|
grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
|
||||||
|
loss1 = (out1.dists * grad_dist).sum()
|
||||||
|
loss1.backward()
|
||||||
|
loss2 = (out2.dists * grad_dist).sum()
|
||||||
|
loss2.backward()
|
||||||
|
|
||||||
|
self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
|
||||||
|
self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
|
||||||
|
|
||||||
|
def test_knn_vs_python_square_cpu(self):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
self._knn_vs_python_square_helper(device)
|
||||||
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
|
||||||
N = 4
|
|
||||||
D = 3
|
|
||||||
Ks = [1, 9, 10, 11, 101]
|
|
||||||
sorts = [False, True]
|
|
||||||
factors = [Ks, sorts]
|
|
||||||
for K, sort in product(*factors):
|
|
||||||
x = torch.randn(N, lengths1.max(), D, device=device)
|
|
||||||
y = torch.randn(N, lengths2.max(), D, device=device)
|
|
||||||
out1 = _knn_points_idx_naive(
|
|
||||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
|
||||||
)
|
|
||||||
out2 = knn_points_idx(
|
|
||||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
|
||||||
)
|
|
||||||
self._check_knn_result(out1, out2, sort)
|
|
||||||
|
|
||||||
def test_knn_vs_python_cuda_ragged(self):
|
def test_knn_vs_python_square_cuda(self):
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda:0")
|
||||||
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
|
self._knn_vs_python_square_helper(device)
|
||||||
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
|
|
||||||
N = 4
|
def _knn_vs_python_ragged_helper(self, device):
|
||||||
D = 3
|
Ns = [1, 4]
|
||||||
Ks = [1, 9, 10, 11, 101]
|
Ds = [3, 5, 8]
|
||||||
sorts = [True, False]
|
P1s = [8, 24]
|
||||||
versions = [0, 1, 2, 3]
|
P2s = [8, 16, 32]
|
||||||
factors = [Ks, sorts]
|
Ks = [1, 3, 10]
|
||||||
for K, sort in product(*factors):
|
factors = [Ns, Ds, P1s, P2s, Ks]
|
||||||
x = torch.randn(N, lengths1.max(), D, device=device)
|
for N, D, P1, P2, K in product(*factors):
|
||||||
y = torch.randn(N, lengths2.max(), D, device=device)
|
x = torch.rand((N, P1, D), device=device, requires_grad=True)
|
||||||
out1 = _knn_points_idx_naive(
|
y = torch.rand((N, P2, D), device=device, requires_grad=True)
|
||||||
|
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
|
||||||
|
lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
|
||||||
|
|
||||||
|
x_csrc = x.clone().detach()
|
||||||
|
x_csrc.requires_grad_(True)
|
||||||
|
y_csrc = y.clone().detach()
|
||||||
|
y_csrc.requires_grad_(True)
|
||||||
|
|
||||||
|
# forward
|
||||||
|
out1 = self._knn_points_naive(
|
||||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
x, y, lengths1=lengths1, lengths2=lengths2, K=K
|
||||||
)
|
)
|
||||||
for version in versions:
|
out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
|
||||||
if version == 3 and K > 4:
|
self.assertClose(out1[0], out2[0])
|
||||||
continue
|
self.assertTrue(torch.all(out1[1] == out2[1]))
|
||||||
out2 = knn_points_idx(
|
|
||||||
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
|
# backward
|
||||||
)
|
grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
|
||||||
self._check_knn_result(out1, out2, sort)
|
loss1 = (out1.dists * grad_dist).sum()
|
||||||
|
loss1.backward()
|
||||||
|
loss2 = (out2.dists * grad_dist).sum()
|
||||||
|
loss2.backward()
|
||||||
|
|
||||||
|
self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
|
||||||
|
self.assertClose(y_csrc.grad, y.grad, atol=5e-6)
|
||||||
|
|
||||||
|
def test_knn_vs_python_ragged_cpu(self):
|
||||||
|
device = torch.device("cpu")
|
||||||
|
self._knn_vs_python_ragged_helper(device)
|
||||||
|
|
||||||
|
def test_knn_vs_python_ragged_cuda(self):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
self._knn_vs_python_ragged_helper(device)
|
||||||
|
|
||||||
|
def test_knn_gather(self):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
N, P1, P2, K, D = 4, 16, 12, 8, 3
|
||||||
|
x = torch.rand((N, P1, D), device=device)
|
||||||
|
y = torch.rand((N, P2, D), device=device)
|
||||||
|
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
|
||||||
|
lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
|
||||||
|
|
||||||
|
out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K)
|
||||||
|
y_nn = knn_gather(y, out.idx, lengths2)
|
||||||
|
|
||||||
|
for n in range(N):
|
||||||
|
for p1 in range(P1):
|
||||||
|
for k in range(K):
|
||||||
|
if k < lengths2[n]:
|
||||||
|
self.assertClose(y_nn[n, p1, k], y[n, out.idx[n, p1, k]])
|
||||||
|
else:
|
||||||
|
self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
|
||||||
|
device = torch.device(device)
|
||||||
|
pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
|
||||||
|
pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
|
||||||
|
grad_dists = torch.randn(N, P1, K, device=device)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def output():
|
||||||
|
out = knn_points(pts1, pts2, K=K)
|
||||||
|
loss = (out.dists * grad_dists).sum()
|
||||||
|
loss.backward()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def knn_ragged(N: int, P1: int, P2: int, D: int, K: int, device: str):
|
||||||
|
device = torch.device(device)
|
||||||
|
pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
|
||||||
|
pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
|
||||||
|
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
|
||||||
|
lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
|
||||||
|
grad_dists = torch.randn(N, P1, K, device=device)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def output():
|
||||||
|
out = knn_points(pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K)
|
||||||
|
loss = (out.dists * grad_dists).sum()
|
||||||
|
loss.backward()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return output
|
||||||
|
Loading…
x
Reference in New Issue
Block a user