Ball Query

Summary:
Implementation of ball query from PointNet++.  This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences:
-  It will return the **first** K neighbors within a specified radius as opposed to the **closest** K neighbors.
- As all the points in p2 do not need to be considered to find the closest K, the algorithm is much faster than KNN when p2 has a large number of points.
- The neighbors are not sorted
- Due to the radius threshold it is not guaranteed that there will be K neighbors even if there are more than K points in p2.
- The padding value for `idx` is -1 instead of 0.

# Note:
- Some of the code is very similar to KNN so it could be possible to modify the KNN forward kernels to support ball query.
- Some users might want to use kNN with ball query - for this we could provide a wrapper function around the current `knn_points` which enables applying the radius threshold afterwards as an alternative. This could be called `ball_query_knn`.

Reviewed By: jcjohnson

Differential Revision: D30261362

fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
This commit is contained in:
Nikhila Ravi 2021-08-12 14:05:23 -07:00 committed by Facebook GitHub Bot
parent e5c58a8a8b
commit 103da63393
10 changed files with 709 additions and 1 deletions

View File

@ -0,0 +1,130 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "utils/pytorch3d_cutils.h"
// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).
// Ball query forward kernel.
//
// For every valid point i of cloud n in p1, scans the points of p2[n] in index
// order and records the first (up to) K of them whose squared distance to
// p1[n][i] is strictly less than radius2. The index goes to idxs[n][i][k] and
// the squared distance to dists[n][i][k]. Slots that are never written keep
// whatever value the caller initialised them with.
//
// Work distribution: block b handles chunks b, b + gridDim.x, b + 2*gridDim.x,
// ... and inside a chunk each thread handles one point of p1.
template <typename scalar_t>
__global__ void BallQueryKernel(
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths1,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths2,
    at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
    const int64_t K,
    const float radius2) {
  const int64_t num_clouds = p1.size(0);
  // ceil(P1 / blockDim.x), written as 1 + (P1 - 1) / blockDim.x
  const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
  const int64_t total_chunks = num_clouds * chunks_per_cloud;
  const int D = p1.size(2);

  for (int64_t chunk = blockIdx.x; chunk < total_chunks; chunk += gridDim.x) {
    // Decompose the flat chunk id into (cloud, chunk within the cloud)
    const int64_t n = chunk / chunks_per_cloud;
    const int64_t chunk_in_cloud = chunk % chunks_per_cloud;
    // Point of p1[n] owned by this thread
    const int64_t i = blockDim.x * chunk_in_cloud + threadIdx.x;

    // Only points inside the valid length of this (possibly padded) cloud
    // do any work
    if (i < lengths1[n]) {
      int64_t found = 0;
      // Walk p2[n] in order; stop once K neighbors were collected or the
      // cloud is exhausted
      for (int64_t j = 0; j < lengths2[n] && found < K; ++j) {
        // Squared Euclidean distance between p1[n][i] and p2[n][j]
        scalar_t dist2 = 0.0;
        for (int d = 0; d < D; ++d) {
          const scalar_t diff = p1[n][i][d] - p2[n][j][d];
          dist2 += (diff * diff);
        }
        // Outside the ball: nothing to record for this candidate
        if (!(dist2 < radius2)) {
          continue;
        }
        // Inside the ball: store the neighbor in the next free slot
        idxs[n][i][found] = j;
        dists[n][i][found] = dist2;
        ++found;
      }
    }
  }
}
// Host-side launcher for the CUDA ball query.
//
// Args:
//    p1: (N, P1, D) query points.
//    p2: (N, P2, D) points searched for neighbors; same dtype as p1.
//    lengths1: (N,) int64 number of valid points in each cloud of p1.
//    lengths2: (N,) int64 number of valid points in each cloud of p2.
//    K: maximum number of neighbors recorded per query point.
//    radius: ball radius; the kernel compares squared distances against
//        radius * radius.
//
// Returns:
//    (idxs, dists): idxs is an int64 tensor of shape (N, P1, K) initialised
//    to -1, dists has the dtype of p1, shape (N, P1, K) and is initialised
//    to 0. Slots for which the kernel finds a neighbor are overwritten.
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
    const at::Tensor& p1, // (N, P1, D)
    const at::Tensor& p2, // (N, P2, D)
    const at::Tensor& lengths1, // (N,)
    const at::Tensor& lengths2, // (N,)
    int K,
    float radius) {
  // Check inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
  at::CheckedFrom c = "BallQueryCuda";
  at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
  at::checkAllSameType(c, {p1_t, p2_t});

  // Set the device for the kernel launch based on the device of p1
  at::cuda::CUDAGuard device_guard(p1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(
      p2.size(2) == p1.size(2), "Point sets must have the same last dimension");

  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int64_t K_64 = K;
  const float radius2 = radius * radius;

  // Output tensors: neighbor indices (padded with -1) and squared distances
  // (padded with 0) for each point in p1
  auto long_dtype = lengths1.options().dtype(at::kLong);
  auto idxs = at::full({N, P1, K}, -1, long_dtype);
  auto dists = at::zeros({N, P1, K}, p1.options());

  // Nothing to compute for empty outputs; skip the launch entirely
  if (idxs.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(idxs, dists);
  }

  const size_t blocks = 256;
  const size_t threads = 256;

  // The accessors must use the dispatched scalar_t: the macro instantiates the
  // lambda for both float and double, and an accessor hard-coded to float
  // fails its dtype check as soon as the inputs are double.
  AT_DISPATCH_FLOATING_TYPES(
      p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
        BallQueryKernel<<<blocks, threads, 0, stream>>>(
            p1.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            p2.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
            dists.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            K_64,
            radius2);
      }));

  // Surface launch-configuration errors from the kernel launch above
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(idxs, dists);
}

View File

@ -0,0 +1,91 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Compute indices of K neighbors in pointcloud p2 to points
// in pointcloud p1 which fall within a specified radius
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// K: Integer giving the upper bound on the number of samples to take
// within the radius
// radius: the radius around each point within which the neighbors need to be
// located
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// This is padded with -1s both where a cloud in p2 has fewer than
// K points and where a cloud in p1 has fewer than P1 points and
// also if there are fewer than K points which satisfy the radius
// threshold.
//
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
// distance from each point p1[n, p, :] to its K neighbors
// p2[n, p1_neighbor_idx[n, p, k], :].
// CPU implementation
// Declaration only. Takes the arguments documented in the comment block above
// and returns the (p1_neighbor_idx, p1_neighbor_dists) pair. BallQuery below
// calls it when neither p1 nor p2 is a CUDA tensor.
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius);
// CUDA implementation
// Declaration only. Takes the arguments documented in the comment block above
// and returns the (p1_neighbor_idx, p1_neighbor_dists) pair. BallQuery below
// calls it (only when compiled with WITH_CUDA) if p1 or p2 is a CUDA tensor.
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius);
// Implementation which is exposed
// Note: the backward pass reuses the KNearestNeighborBackward kernel
// Dispatches a ball query to the CUDA or CPU implementation.
//
// If either point tensor lives on a CUDA device, both are required to be CUDA
// tensors and the CUDA implementation is used (an error is raised when the
// extension was built without WITH_CUDA). Otherwise the CPU implementation
// runs. All tensors are made contiguous before being handed on.
inline std::tuple<at::Tensor, at::Tensor> BallQuery(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    int K,
    float radius) {
  const bool wants_gpu = p1.is_cuda() || p2.is_cuda();
  if (wants_gpu) {
#ifdef WITH_CUDA
    // Mixed CPU/CUDA point tensors are rejected here
    CHECK_CUDA(p1);
    CHECK_CUDA(p2);
    const at::Tensor p1_c = p1.contiguous();
    const at::Tensor p2_c = p2.contiguous();
    const at::Tensor lengths1_c = lengths1.contiguous();
    const at::Tensor lengths2_c = lengths2.contiguous();
    return BallQueryCuda(p1_c, p2_c, lengths1_c, lengths2_c, K, radius);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  const at::Tensor p1_c = p1.contiguous();
  const at::Tensor p2_c = p2.contiguous();
  const at::Tensor lengths1_c = lengths1.contiguous();
  const at::Tensor lengths2_c = lengths2.contiguous();
  return BallQueryCpu(p1_c, p2_c, lengths1_c, lengths2_c, K, radius);
}

View File

@ -0,0 +1,55 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <queue>
#include <tuple>
// CPU ball query.
//
// For every valid point i of cloud n in p1, scans the valid points of p2[n] in
// index order and records the first (up to) K whose squared distance to
// p1[n][i] is strictly less than radius * radius.
//
// Args:
//    p1: float tensor of shape (N, P1, D).
//    p2: float tensor of shape (N, P2, D).
//    lengths1: int64 tensor of shape (N,), valid points per cloud in p1.
//    lengths2: int64 tensor of shape (N,), valid points per cloud in p2.
//    K: maximum number of neighbors recorded per query point.
//    radius: ball radius.
//
// Returns:
//    (idxs, dists): idxs is an int64 tensor of shape (N, P1, K) padded with
//    -1; dists has the options of p1, shape (N, P1, K) and is padded with 0.
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    int K,
    float radius) {
  // Same precondition as the CUDA implementation. Without it a smaller last
  // dimension in p2 would be read out of bounds in the distance loop below.
  TORCH_CHECK(
      p2.size(2) == p1.size(2), "Point sets must have the same last dimension");

  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int D = p1.size(2);

  auto long_opts = lengths1.options().dtype(torch::kInt64);
  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
  const float radius2 = radius * radius;

  auto p1_a = p1.accessor<float, 3>();
  auto p2_a = p2.accessor<float, 3>();
  auto lengths1_a = lengths1.accessor<int64_t, 1>();
  auto lengths2_a = lengths2.accessor<int64_t, 1>();
  auto idxs_a = idxs.accessor<int64_t, 3>();
  auto dists_a = dists.accessor<float, 3>();

  for (int n = 0; n < N; ++n) {
    const int64_t length1 = lengths1_a[n];
    const int64_t length2 = lengths2_a[n];
    for (int64_t i = 0; i < length1; ++i) {
      // Stop scanning p2[n] as soon as K neighbors have been collected
      for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
        // Squared Euclidean distance between p1[n][i] and p2[n][j]
        float dist2 = 0;
        for (int d = 0; d < D; ++d) {
          float diff = p1_a[n][i][d] - p2_a[n][j][d];
          dist2 += diff * diff;
        }
        if (dist2 < radius2) {
          dists_a[n][i][count] = dist2;
          idxs_a[n][i][count] = j;
          ++count;
        }
      }
    }
  }
  return std::make_tuple(idxs, dists);
}

View File

@ -12,6 +12,7 @@
// clang-format on
#include "./pulsar/pytorch/renderer.h"
#include "./pulsar/pytorch/tensor_util.h"
#include "ball_query/ball_query.h"
#include "blending/sigmoid_alpha_blend.h"
#include "compositing/alpha_composite.h"
#include "compositing/norm_weighted_sum.h"
@ -38,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
// Ball Query
m.def("ball_query", &BallQuery);
m.def(
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
m.def("gather_scatter", &GatherScatter);

View File

@ -477,6 +477,10 @@ __global__ void KNearestNeighborBackwardKernel(
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
// index of point in p2 corresponding to the k-th nearest neighbor
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
// If the index is the pad value of -1 then ignore it
if (p2_idx == -1) {
continue;
}
const float diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);

View File

@ -99,6 +99,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
for (int64_t i1 = 0; i1 < length1; ++i1) {
for (int64_t k = 0; k < length2; ++k) {
const int64_t i2 = idxs_a[n][i1][k];
// If the index is the pad value of -1 then ignore it
if (i2 == -1) {
continue;
}
for (int64_t d = 0; d < D; ++d) {
const float diff =
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .ball_query import ball_query
from .cameras_alignment import corresponding_cameras_alignment
from .cubify import cubify
from .graph_conv import GraphConv
@ -34,5 +35,4 @@ from .utils import (
)
from .vert_align import vert_align
__all__ = [k for k in globals().keys() if not k.startswith("_")]

150
pytorch3d/ops/ball_query.py Normal file
View File

@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import torch
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from .knn import _KNN
class _ball_query(Function):
    """
    Autograd Function wrapping the C++/CUDA ball query implementations.
    """

    @staticmethod
    def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
        """
        Argument definitions are the same as in the `ball_query` function.
        """
        neighbor_idx, neighbor_dists = _C.ball_query(
            p1, p2, lengths1, lengths2, K, radius
        )
        # The indices are needed by the backward pass but carry no gradient.
        ctx.save_for_backward(p1, p2, lengths1, lengths2, neighbor_idx)
        ctx.mark_non_differentiable(neighbor_idx)
        return neighbor_dists, neighbor_idx

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idx):
        p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
        # TODO(gkioxari) Change cast to floats once we add support for doubles.
        if grad_dists.dtype != torch.float32:
            grad_dists = grad_dists.float()
        if p1.dtype != torch.float32:
            p1 = p1.float()
        if p2.dtype != torch.float32:
            p2 = p2.float()

        # Reuse the KNN backward function
        point_grads = _C.knn_points_backward(
            p1, p2, lengths1, lengths2, idx, grad_dists
        )
        grad_p1, grad_p2 = point_grads
        # One gradient per forward input; only p1 and p2 receive one.
        return grad_p1, grad_p2, None, None, None, None
def ball_query(
    p1: torch.Tensor,
    p2: torch.Tensor,
    lengths1: Union[torch.Tensor, None] = None,
    lengths2: Union[torch.Tensor, None] = None,
    K: int = 500,
    radius: float = 0.2,
    return_nn: bool = True,
):
    """
    Ball Query is an alternative to KNN. It can be
    used to find all points in p2 that are within a specified radius
    to the query point in p1 (with an upper limit of K neighbors).

    The neighbors returned are not necessarily the *nearest* to the
    point in p1, just the first K values in p2 which are within the
    specified radius.

    This method is faster than kNN when there are large numbers of points
    in p2 and the ordering of neighbors is not important compared to the
    distance being within the radius threshold.

    "Ball query's local neighborhood guarantees a fixed region scale thus
    making local region features more generalizable across space, which is
    preferred for tasks requiring local pattern recognition
    (e.g. semantic point labeling)" [1].

    [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
        on Point Sets in a Metric Space", NeurIPS 2017.

    Args:
        p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
            containing up to P1 points of dimension D. These represent the centers
            of the ball queries.
        p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
            containing up to P2 points of dimension D.
        lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
            length of each pointcloud in p1. Or None to indicate that every cloud has
            length P1.
        lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
            length of each pointcloud in p2. Or None to indicate that every cloud has
            length P2.
        K: Integer giving the upper bound on the number of samples to take
            within the radius
        radius: the radius around each point within which the neighbors need to be located
        return_nn: If set to True returns the K neighbor points in p2 for each point in p1.

    Returns:
        A `_KNN` tuple with the fields:

        dists: Tensor of shape (N, P1, K) giving the squared distances to
            the neighbors. This is padded with zeros both where a cloud in p2
            has fewer than K points and where a cloud in p1 has fewer than P1 points
            and also if there are fewer than K points which satisfy the radius
            threshold.

        idx: LongTensor of shape (N, P1, K) giving the indices of the
            K neighbors in p2 for points in p1.
            Concretely, if `idx[n, i, k] = j` then `p2[n, j]` is the k-th
            neighbor to `p1[n, i]` in `p2[n]`. This is padded with -1 both where a
            cloud in p2 has fewer than K points and where a cloud in p1 has fewer
            than P1 points and also if there are fewer than K points which satisfy
            the radius threshold.

        knn: Tensor of shape (N, P1, K, D) giving the K neighbors in p2 for
            each point in p1. Concretely, `knn[n, i, k]` gives the k-th neighbor
            for `p1[n, i]`; padded entries are set to zero. None if `return_nn`
            is False.

    Raises:
        ValueError: if p1 and p2 differ in batch size or in point dimension.
    """
    # The messages name the actual parameters (p1 / p2) so that the caller can
    # map the error back to the arguments that were passed in.
    if p1.shape[0] != p2.shape[0]:
        raise ValueError("p1 and p2 must have the same batch dimension.")
    if p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have the same point dimension.")

    p1 = p1.contiguous()
    p2 = p2.contiguous()
    P1 = p1.shape[1]
    P2 = p2.shape[1]
    D = p2.shape[2]
    N = p1.shape[0]

    # Missing lengths mean that every cloud uses its full padded size.
    if lengths1 is None:
        lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
    if lengths2 is None:
        lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)

    # pyre-fixme[16]: `_ball_query` has no attribute `apply`.
    dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)

    # Gather the neighbors if needed
    points_nn = None
    if return_nn:
        idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D)
        idx_mask = idx_expanded.eq(-1)
        idx_new = idx_expanded.clone()
        # Replace -1 values with 0 for gather (-1 is not a valid gather index)
        idx_new[idx_mask] = 0
        # Gather points from p2
        points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new)
        # Replace padded values
        points_nn[idx_mask] = 0.0

    return _KNN(dists=dists, idx=idx, knn=points_nn)

40
tests/bm_ball_query.py Normal file
View File

@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
from fvcore.common.benchmark import benchmark
from test_ball_query import TestBallQuery
def bm_ball_query() -> None:
    """
    Benchmark ball query (forward + backward) on square and ragged inputs for
    every combination of the sizes, radii and backends listed below.
    """
    backends = ["cpu", "cuda:0"]

    Ns = [32]
    P1s = [256]
    P2s = [128, 512]
    Ds = [3, 10]
    Ks = [3, 24, 100]
    Rs = [0.1, 0.2, 5]

    # Keyword names expected by the TestBallQuery benchmark helpers, in the
    # same order as the factors handed to `product`.
    arg_names = ("N", "P1", "P2", "D", "K", "radius", "device")
    kwargs_list = [
        dict(zip(arg_names, case))
        for case in product(Ns, P1s, P2s, Ds, Ks, Rs, backends)
    ]

    benchmark(
        TestBallQuery.ball_query_square, "BALLQUERY_SQUARE", kwargs_list, warmup_iters=1
    )
    benchmark(
        TestBallQuery.ball_query_ragged, "BALLQUERY_RAGGED", kwargs_list, warmup_iters=1
    )
if __name__ == "__main__":
bm_ball_query()

230
tests/test_ball_query.py Normal file
View File

@ -0,0 +1,230 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from itertools import product
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.ops.ball_query import ball_query
from pytorch3d.ops.knn import _KNN
from pytorch3d.utils import ico_sphere
class TestBallQuery(TestCaseMixin, unittest.TestCase):
    """
    Compares the C++/CUDA ball query against a naive PyTorch reference
    (forward values and gradients) and provides the closures used by the
    ball query benchmark.
    """

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the randomly generated clouds are reproducible.
        torch.manual_seed(1)

    @staticmethod
    def _ball_query_naive(
        p1, p2, lengths1, lengths2, K: int, radius: float
    ) -> _KNN:
        """
        Naive PyTorch implementation of ball query.

        Returns a `_KNN` tuple with `dists` of shape (N, P1, K) padded with
        zeros, `idx` of shape (N, P1, K) padded with -1 and `knn` set to None.
        """
        N, P1, D = p1.shape
        _N, P2, _D = p2.shape

        assert N == _N and D == _D

        if lengths1 is None:
            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
        if lengths2 is None:
            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)

        radius2 = radius * radius
        dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
        idx = torch.full((N, P1, K), fill_value=-1, dtype=torch.int64, device=p1.device)

        # Iterate through the batches
        for n in range(N):
            num1 = lengths1[n].item()
            num2 = lengths2[n].item()

            # Iterate through the points in the p1
            for i in range(num1):
                # Iterate through the points in the p2
                count = 0
                for j in range(num2):
                    dist = p2[n, j] - p1[n, i]
                    dist2 = (dist * dist).sum()
                    if dist2 < radius2 and count < K:
                        dists[n, i, count] = dist2
                        idx[n, i, count] = j
                        count += 1

        return _KNN(dists=dists, idx=idx, knn=None)

    def _ball_query_vs_python_square_helper(self, device):
        """
        Compare `ball_query` with the naive reference on inputs where every
        cloud uses its full length (no `lengths1` / `lengths2`).
        """
        Ns = [1, 4]
        Ds = [3, 5, 8]
        P1s = [8, 24]
        P2s = [8, 16, 32]
        Ks = [1, 5]
        Rs = [3, 5]
        factors = [Ns, Ds, P1s, P2s, Ks, Rs]
        for N, D, P1, P2, K, R in product(*factors):
            x = torch.randn(N, P1, D, device=device, requires_grad=True)
            x_cuda = x.clone().detach()
            x_cuda.requires_grad_(True)
            y = torch.randn(N, P2, D, device=device, requires_grad=True)
            y_cuda = y.clone().detach()
            y_cuda.requires_grad_(True)

            # forward
            out1 = self._ball_query_naive(
                x, y, lengths1=None, lengths2=None, K=K, radius=R
            )
            out2 = ball_query(x_cuda, y_cuda, K=K, radius=R)

            # Check dists
            self.assertClose(out1.dists, out2.dists)
            # Check idx
            self.assertTrue(torch.all(out1.idx == out2.idx))

            # backward
            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            loss1 = (out1.dists * grad_dist).sum()
            loss1.backward()
            loss2 = (out2.dists * grad_dist).sum()
            loss2.backward()

            self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
            self.assertClose(y_cuda.grad, y.grad, atol=5e-6)

    def test_ball_query_vs_python_square_cpu(self):
        device = torch.device("cpu")
        self._ball_query_vs_python_square_helper(device)

    def test_ball_query_vs_python_square_cuda(self):
        device = get_random_cuda_device()
        self._ball_query_vs_python_square_helper(device)

    def _ball_query_vs_python_ragged_helper(self, device):
        """
        Compare `ball_query` with the naive reference on inputs where each
        cloud has its own randomly drawn length.
        """
        Ns = [1, 4]
        Ds = [3, 5, 8]
        P1s = [8, 24]
        P2s = [8, 16, 32]
        Ks = [2, 3, 10]
        Rs = [1.4, 5]  # radius
        factors = [Ns, Ds, P1s, P2s, Ks, Rs]
        for N, D, P1, P2, K, R in product(*factors):
            x = torch.rand((N, P1, D), device=device, requires_grad=True)
            y = torch.rand((N, P2, D), device=device, requires_grad=True)
            lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
            lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

            x_csrc = x.clone().detach()
            x_csrc.requires_grad_(True)
            y_csrc = y.clone().detach()
            y_csrc.requires_grad_(True)

            # forward
            out1 = self._ball_query_naive(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K, radius=R
            )
            out2 = ball_query(
                x_csrc,
                y_csrc,
                lengths1=lengths1,
                lengths2=lengths2,
                K=K,
                radius=R,
            )

            self.assertClose(out1.idx, out2.idx)
            self.assertClose(out1.dists, out2.dists)

            # backward
            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            loss1 = (out1.dists * grad_dist).sum()
            loss1.backward()
            loss2 = (out2.dists * grad_dist).sum()
            loss2.backward()

            self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
            self.assertClose(y_csrc.grad, y.grad, atol=5e-6)

    def test_ball_query_vs_python_ragged_cpu(self):
        device = torch.device("cpu")
        self._ball_query_vs_python_ragged_helper(device)

    def test_ball_query_vs_python_ragged_cuda(self):
        device = get_random_cuda_device()
        self._ball_query_vs_python_ragged_helper(device)

    def test_ball_query_output_simple(self):
        device = get_random_cuda_device()
        N, P1, P2, K = 5, 8, 16, 4
        sphere = ico_sphere(level=2, device=device).extend(N)
        points_1 = sample_points_from_meshes(sphere, P1)
        points_2 = sample_points_from_meshes(sphere, P2) * 5.0
        radius = 6.0

        naive_out = self._ball_query_naive(
            points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
        )
        cuda_out = ball_query(points_1, points_2, K=K, radius=radius)

        # All points should have K sample neighbors as radius is large
        # Zero is a valid index but can only be present once (i.e. no zero padding)
        naive_out_zeros = (naive_out.idx == 0).sum(dim=-1).max()
        cuda_out_zeros = (cuda_out.idx == 0).sum(dim=-1).max()
        self.assertTrue(naive_out_zeros == 0 or naive_out_zeros == 1)
        self.assertTrue(cuda_out_zeros == 0 or cuda_out_zeros == 1)

        # All points should now have zero sample neighbors as radius is small
        radius = 0.5
        naive_out = self._ball_query_naive(
            points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
        )
        cuda_out = ball_query(points_1, points_2, K=K, radius=radius)
        # Every entry must be the -1 pad value. `.all()` is required here:
        # `.sum()` would be truthy as soon as a single entry is padded.
        naive_out_allzeros = (naive_out.idx == -1).all()
        cuda_out_allzeros = (cuda_out.idx == -1).all()
        self.assertTrue(naive_out_allzeros)
        self.assertTrue(cuda_out_allzeros)

    @staticmethod
    def ball_query_square(
        N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
    ):
        """
        Build a benchmark closure running ball query forward + backward on
        inputs where every cloud uses its full length.
        """
        device = torch.device(device)
        pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
        pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
        grad_dists = torch.randn(N, P1, K, device=device)
        # Only synchronize when benchmarking a CUDA device, so that the CPU
        # backend also runs on machines without CUDA.
        if device.type == "cuda":
            torch.cuda.synchronize()

        def output():
            out = ball_query(pts1, pts2, K=K, radius=radius)
            loss = (out.dists * grad_dists).sum()
            loss.backward()
            if device.type == "cuda":
                torch.cuda.synchronize()

        return output

    @staticmethod
    def ball_query_ragged(
        N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
    ):
        """
        Build a benchmark closure running ball query forward + backward on
        inputs where each cloud has its own randomly drawn length.
        """
        device = torch.device(device)
        pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
        pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
        grad_dists = torch.randn(N, P1, K, device=device)
        # Only synchronize when benchmarking a CUDA device, so that the CPU
        # backend also runs on machines without CUDA.
        if device.type == "cuda":
            torch.cuda.synchronize()

        def output():
            out = ball_query(
                pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K, radius=radius
            )
            loss = (out.dists * grad_dists).sum()
            loss.backward()
            if device.type == "cuda":
                torch.cuda.synchronize()

        return output