knn autograd

Summary:
Adds knn backward to return `grad_pts1` and `grad_pts2`. Adds `knn_gather` to return the nearest neighbors in pts2.

The BM tests include backward pass and are ran on an M40.
```
Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_SQUARE_32_256_128_3_24_cpu              39558           43485             13
KNN_SQUARE_32_256_128_3_24_cuda:0            1080            1404            463
KNN_SQUARE_32_256_512_3_24_cpu              81950           85781              7
KNN_SQUARE_32_256_512_3_24_cuda:0            1519            1641            330
--------------------------------------------------------------------------------

Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_RAGGED_32_256_128_3_24_cpu              13798           14650             37
KNN_RAGGED_32_256_128_3_24_cuda:0            1576            1713            318
KNN_RAGGED_32_256_512_3_24_cpu              31255           32210             16
KNN_RAGGED_32_256_512_3_24_cuda:0            2024            2162            248
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20945556

fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
This commit is contained in:
Georgia Gkioxari
2020-04-14 17:20:16 -07:00
committed by Facebook GitHub Bot
parent 487d4d6607
commit b2b0c5a442
8 changed files with 545 additions and 365 deletions

View File

@@ -20,6 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("packed_to_padded", &PackedToPadded);
m.def("padded_to_packed", &PaddedToPacked);
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
m.def("nn_points_idx", &NearestNeighborIdx);
m.def("gather_scatter", &gather_scatter);
m.def("rasterize_points", &RasterizePoints);

View File

@@ -412,3 +412,93 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
return std::make_tuple(idxs, dists);
}
// ------------------------------------------------------------- //
// Backward Operators //
// ------------------------------------------------------------- //
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void KNearestNeighborBackwardKernel(
const float* __restrict__ p1, // (N, P1, D)
const float* __restrict__ p2, // (N, P2, D)
const int64_t* __restrict__ lengths1, // (N,)
const int64_t* __restrict__ lengths2, // (N,)
const int64_t* __restrict__ idxs, // (N, P1, K)
const float* __restrict__ grad_dists, // (N, P1, K)
float* __restrict__ grad_p1, // (N, P1, D)
float* __restrict__ grad_p2, // (N, P2, D)
const size_t N,
const size_t P1,
const size_t P2,
const size_t K,
const size_t D) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;
for (size_t i = tid; i < N * P1 * K * D; i += stride) {
const size_t n = i / (P1 * K * D); // batch index
size_t rem = i % (P1 * K * D);
const size_t p1_idx = rem / (K * D); // index of point in p1
rem = rem % (K * D);
const size_t k = rem / D; // k-th nearest neighbor
const size_t d = rem % D; // d-th dimension in the feature vector
const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
if ((p1_idx < num1) && (k < num2)) {
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
// index of point in p2 corresponding to the k-th nearest neighbor
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
const float diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
}
}
}
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const at::Tensor& grad_dists) {
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const auto K = idxs.size(2);
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
AT_ASSERTM(
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
AT_ASSERTM(grad_dists.size(0) == N);
AT_ASSERTM(grad_dists.size(1) == P1);
AT_ASSERTM(grad_dists.size(2) == K);
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
const int blocks = 64;
const int threads = 512;
KNearestNeighborBackwardKernel<<<blocks, threads>>>(
p1.data_ptr<float>(),
p2.data_ptr<float>(),
lengths1.data_ptr<int64_t>(),
lengths2.data_ptr<int64_t>(),
idxs.data_ptr<int64_t>(),
grad_dists.data_ptr<float>(),
grad_p1.data_ptr<float>(),
grad_p2.data_ptr<float>(),
N,
P1,
P2,
K,
D);
return std::make_tuple(grad_p1, grad_p2);
}

View File

@@ -16,8 +16,6 @@
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// K: int giving the number of nearest points to return.
// sorted: bool telling whether to sort the K returned points by their
// distance.
// version: Integer telling which implementation to use.
//
// Returns:
@@ -67,3 +65,66 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
}
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
}
// Compute gradients with respect to p1 and p2
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation. This is computed from the forward pass.
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
// gradients.
//
// Returns:
// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
// wrt p1.
// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
// wrt p2.
// CPU implementation.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const at::Tensor& grad_dists);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const at::Tensor& grad_dists);
// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const at::Tensor& grad_dists) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(p1);
CHECK_CONTIGUOUS_CUDA(p2);
return KNearestNeighborBackwardCuda(
p1, p2, lengths1, lengths2, idxs, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborBackwardCpu(
p1, p2, lengths1, lengths2, idxs, grad_dists);
}

View File

@@ -57,3 +57,51 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
}
return std::make_tuple(idxs, dists);
}
// ------------------------------------------------------------- //
// Backward Operators //
// ------------------------------------------------------------- //
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const at::Tensor& grad_dists) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);
const int K = idxs.size(2);
torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto lengths1_a = lengths1.accessor<int64_t, 1>();
auto lengths2_a = lengths2.accessor<int64_t, 1>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto grad_dists_a = grad_dists.accessor<float, 3>();
auto grad_p1_a = grad_p1.accessor<float, 3>();
auto grad_p2_a = grad_p2.accessor<float, 3>();
for (int n = 0; n < N; ++n) {
const int64_t length1 = lengths1_a[n];
int64_t length2 = lengths2_a[n];
length2 = (length2 < K) ? length2 : K;
for (int64_t i1 = 0; i1 < length1; ++i1) {
for (int64_t k = 0; k < length2; ++k) {
const int64_t i2 = idxs_a[n][i1][k];
for (int64_t d = 0; d < D; ++d) {
const float diff =
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
grad_p1_a[n][i1][d] += diff;
grad_p2_a[n][i2][d] += -1.0f * diff;
}
}
}
}
return std::make_tuple(grad_p1, grad_p2);
}

View File

@@ -3,6 +3,7 @@
from .cubify import cubify
from .graph_conv import GraphConv
from .knn import knn_gather, knn_points
from .mesh_face_areas_normals import mesh_face_areas_normals
from .nearest_neighbor_points import nn_points_idx
from .packed_to_padded import packed_to_padded, padded_to_packed

View File

@@ -1,152 +1,215 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from collections import namedtuple
from typing import Union
import torch
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable
def knn_points_idx(
p1,
p2,
K: int,
lengths1=None,
lengths2=None,
sorted: bool = False,
_KNN = namedtuple("KNN", "dists idx knn")
class _knn_points(Function):
"""
Torch autograd Function wrapper for KNN C++/CUDA implementations.
"""
@staticmethod
def forward(ctx, p1, p2, lengths1, lengths2, K, version):
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
containing up to P2 points of dimension D.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
Returns:
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points.
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
"""
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
# sort KNN in ascending order if K > 1
if K > 1:
if lengths2.min() < K:
P1 = p1.shape[1]
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
# mask has shape [N, K], true where dists irrelevant
mask = mask[:, None].expand(-1, P1, -1)
# mask has shape [N, P1, K], true where dists irrelevant
dists[mask] = float("inf")
dists, sort_idx = dists.sort(dim=2)
dists[mask] = 0
else:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
return dists, idx
@staticmethod
@once_differentiable
def backward(ctx, grad_dists, grad_idx):
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
# TODO(gkioxari) Change cast to floats once we add support for doubles.
if not (grad_dists.dtype == torch.float32):
grad_dists = grad_dists.float()
if not (p1.dtype == torch.float32):
p1 = p1.float()
if not (p2.dtype == torch.float32):
p2 = p2.float()
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, grad_dists
)
return grad_p1, grad_p2, None, None, None, None
def knn_points(
p1: torch.Tensor,
p2: torch.Tensor,
lengths1: Union[torch.Tensor, None] = None,
lengths2: Union[torch.Tensor, None] = None,
K: int = 1,
version: int = -1,
return_nn: bool = False,
):
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
containing up to P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
sorted: Whether to sort the resulting points.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
return_nn: If set to True returns the K nearest neighors in p2 for each point in p1.
Returns:
p1_neighbor_idx: LongTensor of shape (N, P1, K) giving the indices of the
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is one of the K nearest
neighbors to p1[n, i] in p2[n]. If sorted=True, then p2[n, j] is the kth
nearest neighbor to p1[n, i]. This is padded with zeros both where a cloud
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
points.
If you want an (N, P1, K, D) tensor of the actual points, you can get it
using
p2[:, :, None].expand(-1, -1, K, -1).gather(1,
x_idx[:, :, :, None].expand(-1, -1, -1, D)
)
If K=1 and you want an (N, P1, D) tensor of the actual points, use
p2.gather(1, x_idx.expand(-1, -1, D))
p1_neighbor_dists: Tensor of shape (N, P1, K) giving the squared distances to
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points.
Warning: this is calculated outside of the autograd framework.
p2_nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor
for `p1[n, i]`. Returned if `return_nn` is True.
The nearest neighbors are collected using `knn_gather`
.. code-block::
p2_nn = knn_gather(p2, p1_idx, lengths2)
which is a helper function that allows indexing any tensor of shape (N, P2, U) with
the indices `p1_idx` returned by `knn_points`. The outout is a tensor
of shape (N, P1, K, U).
"""
if p1.shape[0] != p2.shape[0]:
raise ValueError("pts1 and pts2 must have the same batch dimension.")
if p1.shape[2] != p2.shape[2]:
raise ValueError("pts1 and pts2 must have the same point dimension.")
P1 = p1.shape[1]
P2 = p2.shape[1]
if lengths1 is None:
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
if lengths2 is None:
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
if sorted:
if lengths2.min() < K:
device = dists.device
mask1 = lengths2[:, None] <= torch.arange(K, device=device)[None]
# mask1 has shape [N, K], true where dists irrelevant
mask2 = mask1[:, None].expand(-1, P1, -1)
# mask2 has shape [N, P1, K], true where dists irrelevant
dists[mask2] = float("inf")
dists, sort_idx = dists.sort(dim=2)
dists[mask2] = 0
else:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
return idx, dists
p1_dists, p1_idx = _knn_points.apply(p1, p2, lengths1, lengths2, K, version)
p2_nn = None
if return_nn:
p2_nn = knn_gather(p2, p1_idx, lengths2)
return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None)
@torch.no_grad()
def _knn_points_idx_naive(p1, p2, K: int, lengths1, lengths2) -> torch.Tensor:
def knn_gather(
x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None
):
"""
Naive PyTorch implementation of K-Nearest Neighbors.
This is much less efficient than _C.knn_points_idx, but we include this
naive implementation for testing and benchmarking.
A helper function for knn that allows indexing a tensor x with the indices `idx`
returned by `knn_points`.
For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)`
where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D),
then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`.
It can also be applied for any tensor x of shape (N, M, U) where U != D.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of point clouds, each
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of point clouds, each
containing up to P2 points of dimension D.
K: Integer giving the number of nearest neighbors to return.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
x: Tensor of shape (N, M, U) containing U-dimensional features to
be gathered.
idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`.
lengths: LongTensor of shape (N,) of values in the range [0, M], giving the
length of each example in the batch in x. Or None to indicate that every
example has length M.
Returns:
idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if idx[n, i, k] = j then p2[n, j] is the kth nearest neighbor
to p1[n, i]. This is padded with zeros both where a cloud in p2 has fewer
than K points and where a cloud in p1 has fewer than P1 points.
dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest
neighbors. This is padded with zeros both where a cloud in p2 has fewer than
K points and where a cloud in p1 has fewer than P1 points.
x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x
with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`.
If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0.
"""
N, P1, D = p1.shape
_N, P2, _D = p2.shape
N, M, U = x.shape
_N, L, K = idx.shape
assert N == _N and D == _D
if N != _N:
raise ValueError("x and idx must have same batch dimension.")
if lengths1 is None:
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
if lengths2 is None:
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
if lengths is None:
lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device)
p1_copy = p1.clone()
p2_copy = p2.clone()
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U)
# idx_expanded has shape [N, L, K, U]
# We pad the values with infinities so that the smallest differences are
# among actual points.
inf = float("inf")
p1_mask = torch.arange(P1, device=p1.device)[None] >= lengths1[:, None]
p1_copy[p1_mask] = inf
p2_copy[torch.arange(P2, device=p1.device)[None] >= lengths2[:, None]] = -inf
x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)
# p2_nn has shape [N, L, K, U]
# view is safe here: we are merely adding extra dimensions of length 1
diffs = p1_copy.view(N, P1, 1, D) - p2_copy.view(N, 1, P2, D)
dists2 = (diffs * diffs).sum(dim=3)
needs_mask = lengths.min() < K
if needs_mask:
# mask has shape [N, K], true where idx is irrelevant because
# there is less number of points in p2 than K
mask = lengths[:, None] <= torch.arange(K, device=x.device)[None]
# We always sort, because this works well with padding.
out = dists2.topk(min(K, P2), dim=2, largest=False, sorted=True)
# expand mask to shape [N, L, K, U]
mask = mask[:, None].expand(-1, L, -1)
mask = mask[:, :, :, None].expand(-1, -1, -1, U)
x_out[mask] = 0.0
out_indices = out.indices
out_values = out.values
if P2 < K:
# Need to add padding
pad_shape = (N, P1, K - P2)
out_indices = torch.cat([out_indices, out_indices.new_zeros(pad_shape)], 2)
out_values = torch.cat([out_values, out_values.new_zeros(pad_shape)], 2)
K_mask = torch.arange(K, device=p1.device)[None] >= lengths2[:, None]
# Create a combined mask for where the points in p1 are padded
# or the corresponding p2 has fewer than K points.
p1_K_mask = p1_mask[:, :, None] | K_mask[:, None, :]
out_indices[p1_K_mask] = 0
out_values[p1_K_mask] = 0
return out_indices, out_values
return x_out