mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
knn autograd
Summary: Adds knn backward to return `grad_pts1` and `grad_pts2`. Adds `knn_gather` to return the nearest neighbors in pts2. The BM tests include backward pass and are ran on an M40. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- KNN_SQUARE_32_256_128_3_24_cpu 39558 43485 13 KNN_SQUARE_32_256_128_3_24_cuda:0 1080 1404 463 KNN_SQUARE_32_256_512_3_24_cpu 81950 85781 7 KNN_SQUARE_32_256_512_3_24_cuda:0 1519 1641 330 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- KNN_RAGGED_32_256_128_3_24_cpu 13798 14650 37 KNN_RAGGED_32_256_128_3_24_cuda:0 1576 1713 318 KNN_RAGGED_32_256_512_3_24_cpu 31255 32210 16 KNN_RAGGED_32_256_512_3_24_cuda:0 2024 2162 248 -------------------------------------------------------------------------------- ``` Reviewed By: jcjohnson Differential Revision: D20945556 fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
487d4d6607
commit
b2b0c5a442
@@ -20,6 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("packed_to_padded", &PackedToPadded);
|
||||
m.def("padded_to_packed", &PaddedToPacked);
|
||||
m.def("knn_points_idx", &KNearestNeighborIdx);
|
||||
m.def("knn_points_backward", &KNearestNeighborBackward);
|
||||
m.def("nn_points_idx", &NearestNeighborIdx);
|
||||
m.def("gather_scatter", &gather_scatter);
|
||||
m.def("rasterize_points", &RasterizePoints);
|
||||
|
||||
@@ -412,3 +412,93 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
|
||||
return std::make_tuple(idxs, dists);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------- //
|
||||
// Backward Operators //
|
||||
// ------------------------------------------------------------- //
|
||||
|
||||
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
|
||||
// Currently, support is for floats only.
|
||||
__global__ void KNearestNeighborBackwardKernel(
|
||||
const float* __restrict__ p1, // (N, P1, D)
|
||||
const float* __restrict__ p2, // (N, P2, D)
|
||||
const int64_t* __restrict__ lengths1, // (N,)
|
||||
const int64_t* __restrict__ lengths2, // (N,)
|
||||
const int64_t* __restrict__ idxs, // (N, P1, K)
|
||||
const float* __restrict__ grad_dists, // (N, P1, K)
|
||||
float* __restrict__ grad_p1, // (N, P1, D)
|
||||
float* __restrict__ grad_p2, // (N, P2, D)
|
||||
const size_t N,
|
||||
const size_t P1,
|
||||
const size_t P2,
|
||||
const size_t K,
|
||||
const size_t D) {
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t i = tid; i < N * P1 * K * D; i += stride) {
|
||||
const size_t n = i / (P1 * K * D); // batch index
|
||||
size_t rem = i % (P1 * K * D);
|
||||
const size_t p1_idx = rem / (K * D); // index of point in p1
|
||||
rem = rem % (K * D);
|
||||
const size_t k = rem / D; // k-th nearest neighbor
|
||||
const size_t d = rem % D; // d-th dimension in the feature vector
|
||||
|
||||
const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
|
||||
const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
|
||||
if ((p1_idx < num1) && (k < num2)) {
|
||||
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
|
||||
// index of point in p2 corresponding to the k-th nearest neighbor
|
||||
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
|
||||
const float diff = 2.0 * grad_dist *
|
||||
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
|
||||
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
|
||||
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const at::Tensor& grad_dists) {
|
||||
const auto N = p1.size(0);
|
||||
const auto P1 = p1.size(1);
|
||||
const auto P2 = p2.size(1);
|
||||
const auto D = p2.size(2);
|
||||
const auto K = idxs.size(2);
|
||||
|
||||
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
|
||||
AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
|
||||
AT_ASSERTM(
|
||||
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
|
||||
AT_ASSERTM(grad_dists.size(0) == N);
|
||||
AT_ASSERTM(grad_dists.size(1) == P1);
|
||||
AT_ASSERTM(grad_dists.size(2) == K);
|
||||
|
||||
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
|
||||
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
|
||||
KNearestNeighborBackwardKernel<<<blocks, threads>>>(
|
||||
p1.data_ptr<float>(),
|
||||
p2.data_ptr<float>(),
|
||||
lengths1.data_ptr<int64_t>(),
|
||||
lengths2.data_ptr<int64_t>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
grad_p1.data_ptr<float>(),
|
||||
grad_p2.data_ptr<float>(),
|
||||
N,
|
||||
P1,
|
||||
P2,
|
||||
K,
|
||||
D);
|
||||
|
||||
return std::make_tuple(grad_p1, grad_p2);
|
||||
}
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||
// K: int giving the number of nearest points to return.
|
||||
// sorted: bool telling whether to sort the K returned points by their
|
||||
// distance.
|
||||
// version: Integer telling which implementation to use.
|
||||
//
|
||||
// Returns:
|
||||
@@ -67,3 +65,66 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
||||
}
|
||||
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
|
||||
}
|
||||
|
||||
// Compute gradients with respect to p1 and p2
|
||||
//
|
||||
// Args:
|
||||
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
||||
// containing P1 points of dimension D.
|
||||
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||
// containing P2 points of dimension D.
|
||||
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
||||
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||
// It is padded with zeros so that it can be used easily in a later
|
||||
// gather() operation. This is computed from the forward pass.
|
||||
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
|
||||
// gradients.
|
||||
//
|
||||
// Returns:
|
||||
// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
|
||||
// wrt p1.
|
||||
// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
|
||||
// wrt p2.
|
||||
|
||||
// CPU implementation.
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const at::Tensor& grad_dists);
|
||||
|
||||
// CUDA implementation
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const at::Tensor& grad_dists);
|
||||
|
||||
// Implementation which is exposed.
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const at::Tensor& grad_dists) {
|
||||
if (p1.is_cuda() || p2.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CONTIGUOUS_CUDA(p1);
|
||||
CHECK_CONTIGUOUS_CUDA(p2);
|
||||
return KNearestNeighborBackwardCuda(
|
||||
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return KNearestNeighborBackwardCpu(
|
||||
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
||||
}
|
||||
|
||||
@@ -57,3 +57,51 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||
}
|
||||
return std::make_tuple(idxs, dists);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------- //
|
||||
// Backward Operators //
|
||||
// ------------------------------------------------------------- //
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const at::Tensor& grad_dists) {
|
||||
const int N = p1.size(0);
|
||||
const int P1 = p1.size(1);
|
||||
const int D = p1.size(2);
|
||||
const int P2 = p2.size(1);
|
||||
const int K = idxs.size(2);
|
||||
|
||||
torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
|
||||
torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());
|
||||
|
||||
auto p1_a = p1.accessor<float, 3>();
|
||||
auto p2_a = p2.accessor<float, 3>();
|
||||
auto lengths1_a = lengths1.accessor<int64_t, 1>();
|
||||
auto lengths2_a = lengths2.accessor<int64_t, 1>();
|
||||
auto idxs_a = idxs.accessor<int64_t, 3>();
|
||||
auto grad_dists_a = grad_dists.accessor<float, 3>();
|
||||
auto grad_p1_a = grad_p1.accessor<float, 3>();
|
||||
auto grad_p2_a = grad_p2.accessor<float, 3>();
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
const int64_t length1 = lengths1_a[n];
|
||||
int64_t length2 = lengths2_a[n];
|
||||
length2 = (length2 < K) ? length2 : K;
|
||||
for (int64_t i1 = 0; i1 < length1; ++i1) {
|
||||
for (int64_t k = 0; k < length2; ++k) {
|
||||
const int64_t i2 = idxs_a[n][i1][k];
|
||||
for (int64_t d = 0; d < D; ++d) {
|
||||
const float diff =
|
||||
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
|
||||
grad_p1_a[n][i1][d] += diff;
|
||||
grad_p2_a[n][i2][d] += -1.0f * diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_tuple(grad_p1, grad_p2);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user