mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Ball Query
Summary: Implementation of ball query from PointNet++. This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences: - It will return the **first** K neighbors within a specified radius as opposed to the **closest** K neighbors. - As all the points in p2 do not need to be considered to find the closest K, the algorithm is much faster than KNN when p2 has a large number of points. - The neighbors are not sorted - Due to the radius threshold it is not guaranteed that there will be K neighbors even if there are more than K points in p2. - The padding value for `idx` is -1 instead of 0. # Note: - Some of the code is very similar to KNN so it could be possible to modify the KNN forward kernels to support ball query. - Some users might want to use kNN with ball query - for this we could provide a wrapper function around the current `knn_points` which enables applying the radius threshold afterwards as an alternative. This could be called `ball_query_knn`. Reviewed By: jcjohnson Differential Revision: D30261362 fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
This commit is contained in:
parent
e5c58a8a8b
commit
103da63393
130
pytorch3d/csrc/ball_query/ball_query.cu
Normal file
130
pytorch3d/csrc/ball_query/ball_query.cu
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the BSD-style license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include "utils/pytorch3d_cutils.h"
|
||||||
|
|
||||||
|
// A chunk of work is blocksize-many points of P1.
|
||||||
|
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
|
||||||
|
// call (1+(P1-1)/blocksize) chunks_per_cloud
|
||||||
|
// These chunks are divided among the gridSize-many blocks.
|
||||||
|
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
|
||||||
|
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
|
||||||
|
// blocksize*(i%chunks_per_cloud).
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
__global__ void BallQueryKernel(
|
||||||
|
const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
|
||||||
|
const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
|
||||||
|
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
|
||||||
|
lengths1,
|
||||||
|
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
|
||||||
|
lengths2,
|
||||||
|
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
|
||||||
|
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
|
||||||
|
const int64_t K,
|
||||||
|
const float radius2) {
|
||||||
|
const int64_t N = p1.size(0);
|
||||||
|
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
|
||||||
|
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||||
|
const int D = p1.size(2);
|
||||||
|
|
||||||
|
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
|
||||||
|
const int64_t n = chunk / chunks_per_cloud; // batch_index
|
||||||
|
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
|
||||||
|
int64_t i = start_point + threadIdx.x;
|
||||||
|
|
||||||
|
// Check if point is valid in heterogeneous tensor
|
||||||
|
if (i >= lengths1[n]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over points in p2 until desired count is reached or
|
||||||
|
// all points have been considered
|
||||||
|
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
|
||||||
|
// Calculate the distance between the points
|
||||||
|
scalar_t dist2 = 0.0;
|
||||||
|
for (int d = 0; d < D; ++d) {
|
||||||
|
scalar_t diff = p1[n][i][d] - p2[n][j][d];
|
||||||
|
dist2 += (diff * diff);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dist2 < radius2) {
|
||||||
|
// If the point is within the radius
|
||||||
|
// Set the value of the index to the point index
|
||||||
|
idxs[n][i][count] = j;
|
||||||
|
dists[n][i][count] = dist2;
|
||||||
|
|
||||||
|
// increment the number of selected samples for the point i
|
||||||
|
++count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||||
|
const at::Tensor& p1, // (N, P1, 3)
|
||||||
|
const at::Tensor& p2, // (N, P2, 3)
|
||||||
|
const at::Tensor& lengths1, // (N,)
|
||||||
|
const at::Tensor& lengths2, // (N,)
|
||||||
|
int K,
|
||||||
|
float radius) {
|
||||||
|
// Check inputs are on the same device
|
||||||
|
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
|
||||||
|
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
|
||||||
|
at::CheckedFrom c = "BallQueryCuda";
|
||||||
|
at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
|
||||||
|
at::checkAllSameType(c, {p1_t, p2_t});
|
||||||
|
|
||||||
|
// Set the device for the kernel launch based on the device of p1
|
||||||
|
at::cuda::CUDAGuard device_guard(p1.device());
|
||||||
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
p2.size(2) == p1.size(2), "Point sets must have the same last dimension");
|
||||||
|
|
||||||
|
const int N = p1.size(0);
|
||||||
|
const int P1 = p1.size(1);
|
||||||
|
const int64_t K_64 = K;
|
||||||
|
const float radius2 = radius * radius;
|
||||||
|
|
||||||
|
// Output tensor with indices of neighbors for each point in p1
|
||||||
|
auto long_dtype = lengths1.options().dtype(at::kLong);
|
||||||
|
auto idxs = at::full({N, P1, K}, -1, long_dtype);
|
||||||
|
auto dists = at::zeros({N, P1, K}, p1.options());
|
||||||
|
|
||||||
|
if (idxs.numel() == 0) {
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
return std::make_tuple(idxs, dists);
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t blocks = 256;
|
||||||
|
const size_t threads = 256;
|
||||||
|
|
||||||
|
AT_DISPATCH_FLOATING_TYPES(
|
||||||
|
p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
|
||||||
|
BallQueryKernel<<<blocks, threads, 0, stream>>>(
|
||||||
|
p1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
||||||
|
p2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
||||||
|
lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
|
||||||
|
lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
|
||||||
|
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
|
||||||
|
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
||||||
|
K_64,
|
||||||
|
radius2);
|
||||||
|
}));
|
||||||
|
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
return std::make_tuple(idxs, dists);
|
||||||
|
}
|
91
pytorch3d/csrc/ball_query/ball_query.h
Normal file
91
pytorch3d/csrc/ball_query/ball_query.h
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the BSD-style license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <tuple>
|
||||||
|
#include "utils/pytorch3d_cutils.h"
|
||||||
|
|
||||||
|
// Compute indices of K neighbors in pointcloud p2 to points
|
||||||
|
// in pointcloud p1 which fall within a specified radius
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
||||||
|
// containing P1 points of dimension D.
|
||||||
|
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||||
|
// containing P2 points of dimension D.
|
||||||
|
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||||
|
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||||
|
// K: Integer giving the upper bound on the number of samples to take
|
||||||
|
// within the radius
|
||||||
|
// radius: the radius around each point within which the neighbors need to be
|
||||||
|
// located
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||||
|
// p1_neighbor_idx[n, i, k] = j means that the kth
|
||||||
|
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||||
|
// This is padded with -1s both where a cloud in p2 has fewer than
|
||||||
|
// S points and where a cloud in p1 has fewer than P1 points and
|
||||||
|
// also if there are fewer than K points which satisfy the radius
|
||||||
|
// threshold.
|
||||||
|
//
|
||||||
|
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
|
||||||
|
// distance from each point p1[n, p, :] to its K neighbors
|
||||||
|
// p2[n, p1_neighbor_idx[n, p, k], :].
|
||||||
|
|
||||||
|
// CPU implementation
|
||||||
|
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
const int K,
|
||||||
|
const float radius);
|
||||||
|
|
||||||
|
// CUDA implementation
|
||||||
|
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
const int K,
|
||||||
|
const float radius);
|
||||||
|
|
||||||
|
// Implementation which is exposed
|
||||||
|
// Note: the backward pass reuses the KNearestNeighborBackward kernel
|
||||||
|
inline std::tuple<at::Tensor, at::Tensor> BallQuery(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
int K,
|
||||||
|
float radius) {
|
||||||
|
if (p1.is_cuda() || p2.is_cuda()) {
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
CHECK_CUDA(p1);
|
||||||
|
CHECK_CUDA(p2);
|
||||||
|
return BallQueryCuda(
|
||||||
|
p1.contiguous(),
|
||||||
|
p2.contiguous(),
|
||||||
|
lengths1.contiguous(),
|
||||||
|
lengths2.contiguous(),
|
||||||
|
K,
|
||||||
|
radius);
|
||||||
|
#else
|
||||||
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
return BallQueryCpu(
|
||||||
|
p1.contiguous(),
|
||||||
|
p2.contiguous(),
|
||||||
|
lengths1.contiguous(),
|
||||||
|
lengths2.contiguous(),
|
||||||
|
K,
|
||||||
|
radius);
|
||||||
|
}
|
55
pytorch3d/csrc/ball_query/ball_query_cpu.cpp
Normal file
55
pytorch3d/csrc/ball_query/ball_query_cpu.cpp
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the BSD-style license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <queue>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||||
|
const at::Tensor& p1,
|
||||||
|
const at::Tensor& p2,
|
||||||
|
const at::Tensor& lengths1,
|
||||||
|
const at::Tensor& lengths2,
|
||||||
|
int K,
|
||||||
|
float radius) {
|
||||||
|
const int N = p1.size(0);
|
||||||
|
const int P1 = p1.size(1);
|
||||||
|
const int D = p1.size(2);
|
||||||
|
|
||||||
|
auto long_opts = lengths1.options().dtype(torch::kInt64);
|
||||||
|
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
|
||||||
|
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
|
||||||
|
const float radius2 = radius * radius;
|
||||||
|
|
||||||
|
auto p1_a = p1.accessor<float, 3>();
|
||||||
|
auto p2_a = p2.accessor<float, 3>();
|
||||||
|
auto lengths1_a = lengths1.accessor<int64_t, 1>();
|
||||||
|
auto lengths2_a = lengths2.accessor<int64_t, 1>();
|
||||||
|
auto idxs_a = idxs.accessor<int64_t, 3>();
|
||||||
|
auto dists_a = dists.accessor<float, 3>();
|
||||||
|
|
||||||
|
for (int n = 0; n < N; ++n) {
|
||||||
|
const int64_t length1 = lengths1_a[n];
|
||||||
|
const int64_t length2 = lengths2_a[n];
|
||||||
|
for (int64_t i = 0; i < length1; ++i) {
|
||||||
|
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
|
||||||
|
float dist2 = 0;
|
||||||
|
for (int d = 0; d < D; ++d) {
|
||||||
|
float diff = p1_a[n][i][d] - p2_a[n][j][d];
|
||||||
|
dist2 += diff * diff;
|
||||||
|
}
|
||||||
|
if (dist2 < radius2) {
|
||||||
|
dists_a[n][i][count] = dist2;
|
||||||
|
idxs_a[n][i][count] = j;
|
||||||
|
++count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_tuple(idxs, dists);
|
||||||
|
}
|
@ -12,6 +12,7 @@
|
|||||||
// clang-format on
|
// clang-format on
|
||||||
#include "./pulsar/pytorch/renderer.h"
|
#include "./pulsar/pytorch/renderer.h"
|
||||||
#include "./pulsar/pytorch/tensor_util.h"
|
#include "./pulsar/pytorch/tensor_util.h"
|
||||||
|
#include "ball_query/ball_query.h"
|
||||||
#include "blending/sigmoid_alpha_blend.h"
|
#include "blending/sigmoid_alpha_blend.h"
|
||||||
#include "compositing/alpha_composite.h"
|
#include "compositing/alpha_composite.h"
|
||||||
#include "compositing/norm_weighted_sum.h"
|
#include "compositing/norm_weighted_sum.h"
|
||||||
@ -38,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
#endif
|
#endif
|
||||||
m.def("knn_points_idx", &KNearestNeighborIdx);
|
m.def("knn_points_idx", &KNearestNeighborIdx);
|
||||||
m.def("knn_points_backward", &KNearestNeighborBackward);
|
m.def("knn_points_backward", &KNearestNeighborBackward);
|
||||||
|
|
||||||
|
// Ball Query
|
||||||
|
m.def("ball_query", &BallQuery);
|
||||||
m.def(
|
m.def(
|
||||||
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
|
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
|
||||||
m.def("gather_scatter", &GatherScatter);
|
m.def("gather_scatter", &GatherScatter);
|
||||||
|
@ -477,6 +477,10 @@ __global__ void KNearestNeighborBackwardKernel(
|
|||||||
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
|
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
|
||||||
// index of point in p2 corresponding to the k-th nearest neighbor
|
// index of point in p2 corresponding to the k-th nearest neighbor
|
||||||
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
|
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
|
||||||
|
// If the index is the pad value of -1 then ignore it
|
||||||
|
if (p2_idx == -1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
const float diff = 2.0 * grad_dist *
|
const float diff = 2.0 * grad_dist *
|
||||||
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
|
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
|
||||||
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
|
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
|
||||||
|
@ -99,6 +99,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
|||||||
for (int64_t i1 = 0; i1 < length1; ++i1) {
|
for (int64_t i1 = 0; i1 < length1; ++i1) {
|
||||||
for (int64_t k = 0; k < length2; ++k) {
|
for (int64_t k = 0; k < length2; ++k) {
|
||||||
const int64_t i2 = idxs_a[n][i1][k];
|
const int64_t i2 = idxs_a[n][i1][k];
|
||||||
|
// If the index is the pad value of -1 then ignore it
|
||||||
|
if (i2 == -1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
for (int64_t d = 0; d < D; ++d) {
|
for (int64_t d = 0; d < D; ++d) {
|
||||||
const float diff =
|
const float diff =
|
||||||
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
|
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from .ball_query import ball_query
|
||||||
from .cameras_alignment import corresponding_cameras_alignment
|
from .cameras_alignment import corresponding_cameras_alignment
|
||||||
from .cubify import cubify
|
from .cubify import cubify
|
||||||
from .graph_conv import GraphConv
|
from .graph_conv import GraphConv
|
||||||
@ -34,5 +35,4 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
from .vert_align import vert_align
|
from .vert_align import vert_align
|
||||||
|
|
||||||
|
|
||||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||||
|
150
pytorch3d/ops/ball_query.py
Normal file
150
pytorch3d/ops/ball_query.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d import _C
|
||||||
|
from torch.autograd import Function
|
||||||
|
from torch.autograd.function import once_differentiable
|
||||||
|
|
||||||
|
from .knn import _KNN
|
||||||
|
|
||||||
|
|
||||||
|
class _ball_query(Function):
|
||||||
|
"""
|
||||||
|
Torch autograd Function wrapper for Ball Query C++/CUDA implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
|
||||||
|
"""
|
||||||
|
Arguments defintions the same as in the ball_query function
|
||||||
|
"""
|
||||||
|
idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius)
|
||||||
|
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
|
||||||
|
ctx.mark_non_differentiable(idx)
|
||||||
|
return dists, idx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@once_differentiable
|
||||||
|
def backward(ctx, grad_dists, grad_idx):
|
||||||
|
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
||||||
|
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||||
|
if not (grad_dists.dtype == torch.float32):
|
||||||
|
grad_dists = grad_dists.float()
|
||||||
|
if not (p1.dtype == torch.float32):
|
||||||
|
p1 = p1.float()
|
||||||
|
if not (p2.dtype == torch.float32):
|
||||||
|
p2 = p2.float()
|
||||||
|
|
||||||
|
# Reuse the KNN backward function
|
||||||
|
grad_p1, grad_p2 = _C.knn_points_backward(
|
||||||
|
p1, p2, lengths1, lengths2, idx, grad_dists
|
||||||
|
)
|
||||||
|
return grad_p1, grad_p2, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def ball_query(
|
||||||
|
p1: torch.Tensor,
|
||||||
|
p2: torch.Tensor,
|
||||||
|
lengths1: Union[torch.Tensor, None] = None,
|
||||||
|
lengths2: Union[torch.Tensor, None] = None,
|
||||||
|
K: int = 500,
|
||||||
|
radius: float = 0.2,
|
||||||
|
return_nn: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Ball Query is an alternative to KNN. It can be
|
||||||
|
used to find all points in p2 that are within a specified radius
|
||||||
|
to the query point in p1 (with an upper limit of K neighbors).
|
||||||
|
|
||||||
|
The neighbors returned are not necssarily the *nearest* to the
|
||||||
|
point in p1, just the first K values in p2 which are within the
|
||||||
|
specified radius.
|
||||||
|
|
||||||
|
This method is faster than kNN when there are large numbers of points
|
||||||
|
in p2 and the ordering of neighbors is not important compared to the
|
||||||
|
distance being within the radius threshold.
|
||||||
|
|
||||||
|
"Ball query’s local neighborhood guarantees a fixed region scale thus
|
||||||
|
making local region features more generalizable across space, which is
|
||||||
|
preferred for tasks requiring local pattern recognition
|
||||||
|
(e.g. semantic point labeling)" [1].
|
||||||
|
|
||||||
|
[1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
|
||||||
|
on Point Sets in a Metric Space", NeurIPS 2017.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
|
||||||
|
containing up to P1 points of dimension D. These represent the centers of
|
||||||
|
the ball queries.
|
||||||
|
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
|
||||||
|
containing up to P2 points of dimension D.
|
||||||
|
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||||
|
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||||
|
length P1.
|
||||||
|
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||||
|
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||||
|
length P2.
|
||||||
|
K: Integer giving the upper bound on the number of samples to take
|
||||||
|
within the radius
|
||||||
|
radius: the radius around each point within which the neighbors need to be located
|
||||||
|
return_nn: If set to True returns the K neighbor points in p2 for each point in p1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||||
|
the neighbors. This is padded with zeros both where a cloud in p2
|
||||||
|
has fewer than S points and where a cloud in p1 has fewer than P1 points
|
||||||
|
and also if there are fewer than K points which satisfy the radius threshold.
|
||||||
|
|
||||||
|
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||||
|
S neighbors in p2 for points in p1.
|
||||||
|
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th
|
||||||
|
neighbor to `p1[n, i]` in `p2[n]`. This is padded with -1 both where a cloud
|
||||||
|
in p2 has fewer than S points and where a cloud in p1 has fewer than P1
|
||||||
|
points and also if there are fewer than K points which satisfy the radius threshold.
|
||||||
|
|
||||||
|
nn: Tensor of shape (N, P1, K, D) giving the K neighbors in p2 for
|
||||||
|
each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th neighbor
|
||||||
|
for `p1[n, i]`. Returned if `return_nn` is True. The output is a tensor
|
||||||
|
of shape (N, P1, K, U).
|
||||||
|
|
||||||
|
"""
|
||||||
|
if p1.shape[0] != p2.shape[0]:
|
||||||
|
raise ValueError("pts1 and pts2 must have the same batch dimension.")
|
||||||
|
if p1.shape[2] != p2.shape[2]:
|
||||||
|
raise ValueError("pts1 and pts2 must have the same point dimension.")
|
||||||
|
|
||||||
|
p1 = p1.contiguous()
|
||||||
|
p2 = p2.contiguous()
|
||||||
|
P1 = p1.shape[1]
|
||||||
|
P2 = p2.shape[1]
|
||||||
|
D = p2.shape[2]
|
||||||
|
N = p1.shape[0]
|
||||||
|
|
||||||
|
if lengths1 is None:
|
||||||
|
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
|
||||||
|
if lengths2 is None:
|
||||||
|
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
||||||
|
|
||||||
|
# pyre-fixme[16]: `_ball_query` has no attribute `apply`.
|
||||||
|
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
|
||||||
|
|
||||||
|
# Gather the neighbors if needed
|
||||||
|
points_nn = None
|
||||||
|
if return_nn:
|
||||||
|
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D)
|
||||||
|
idx_mask = idx_expanded.eq(-1)
|
||||||
|
idx_new = idx_expanded.clone()
|
||||||
|
# Replace -1 values with 0 for gather
|
||||||
|
idx_new[idx_mask] = 0
|
||||||
|
# Gather points from p2
|
||||||
|
points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new)
|
||||||
|
# Replace padded values
|
||||||
|
points_nn[idx_mask] = 0.0
|
||||||
|
|
||||||
|
return _KNN(dists=dists, idx=idx, knn=points_nn)
|
40
tests/bm_ball_query.py
Normal file
40
tests/bm_ball_query.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
from fvcore.common.benchmark import benchmark
|
||||||
|
from test_ball_query import TestBallQuery
|
||||||
|
|
||||||
|
|
||||||
|
def bm_ball_query() -> None:
|
||||||
|
|
||||||
|
backends = ["cpu", "cuda:0"]
|
||||||
|
|
||||||
|
kwargs_list = []
|
||||||
|
Ns = [32]
|
||||||
|
P1s = [256]
|
||||||
|
P2s = [128, 512]
|
||||||
|
Ds = [3, 10]
|
||||||
|
Ks = [3, 24, 100]
|
||||||
|
Rs = [0.1, 0.2, 5]
|
||||||
|
test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends)
|
||||||
|
for case in test_cases:
|
||||||
|
N, P1, P2, D, K, R, b = case
|
||||||
|
kwargs_list.append(
|
||||||
|
{"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b}
|
||||||
|
)
|
||||||
|
|
||||||
|
benchmark(
|
||||||
|
TestBallQuery.ball_query_square, "BALLQUERY_SQUARE", kwargs_list, warmup_iters=1
|
||||||
|
)
|
||||||
|
benchmark(
|
||||||
|
TestBallQuery.ball_query_ragged, "BALLQUERY_RAGGED", kwargs_list, warmup_iters=1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bm_ball_query()
|
230
tests/test_ball_query.py
Normal file
230
tests/test_ball_query.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||||
|
from pytorch3d.ops import sample_points_from_meshes
|
||||||
|
from pytorch3d.ops.ball_query import ball_query
|
||||||
|
from pytorch3d.ops.knn import _KNN
|
||||||
|
from pytorch3d.utils import ico_sphere
|
||||||
|
|
||||||
|
|
||||||
|
class TestBallQuery(TestCaseMixin, unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
torch.manual_seed(1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ball_query_naive(
|
||||||
|
p1, p2, lengths1, lengths2, K: int, radius: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Naive PyTorch implementation of ball query.
|
||||||
|
"""
|
||||||
|
N, P1, D = p1.shape
|
||||||
|
_N, P2, _D = p2.shape
|
||||||
|
|
||||||
|
assert N == _N and D == _D
|
||||||
|
|
||||||
|
if lengths1 is None:
|
||||||
|
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
|
||||||
|
if lengths2 is None:
|
||||||
|
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
||||||
|
|
||||||
|
radius2 = radius * radius
|
||||||
|
dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
|
||||||
|
idx = torch.full((N, P1, K), fill_value=-1, dtype=torch.int64, device=p1.device)
|
||||||
|
|
||||||
|
# Iterate through the batches
|
||||||
|
for n in range(N):
|
||||||
|
num1 = lengths1[n].item()
|
||||||
|
num2 = lengths2[n].item()
|
||||||
|
|
||||||
|
# Iterate through the points in the p1
|
||||||
|
for i in range(num1):
|
||||||
|
# Iterate through the points in the p2
|
||||||
|
count = 0
|
||||||
|
for j in range(num2):
|
||||||
|
dist = p2[n, j] - p1[n, i]
|
||||||
|
dist2 = (dist * dist).sum()
|
||||||
|
if dist2 < radius2 and count < K:
|
||||||
|
dists[n, i, count] = dist2
|
||||||
|
idx[n, i, count] = j
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return _KNN(dists=dists, idx=idx, knn=None)
|
||||||
|
|
||||||
|
def _ball_query_vs_python_square_helper(self, device):
|
||||||
|
Ns = [1, 4]
|
||||||
|
Ds = [3, 5, 8]
|
||||||
|
P1s = [8, 24]
|
||||||
|
P2s = [8, 16, 32]
|
||||||
|
Ks = [1, 5]
|
||||||
|
Rs = [3, 5]
|
||||||
|
factors = [Ns, Ds, P1s, P2s, Ks, Rs]
|
||||||
|
for N, D, P1, P2, K, R in product(*factors):
|
||||||
|
x = torch.randn(N, P1, D, device=device, requires_grad=True)
|
||||||
|
x_cuda = x.clone().detach()
|
||||||
|
x_cuda.requires_grad_(True)
|
||||||
|
y = torch.randn(N, P2, D, device=device, requires_grad=True)
|
||||||
|
y_cuda = y.clone().detach()
|
||||||
|
y_cuda.requires_grad_(True)
|
||||||
|
|
||||||
|
# forward
|
||||||
|
out1 = self._ball_query_naive(
|
||||||
|
x, y, lengths1=None, lengths2=None, K=K, radius=R
|
||||||
|
)
|
||||||
|
out2 = ball_query(x_cuda, y_cuda, K=K, radius=R)
|
||||||
|
|
||||||
|
# Check dists
|
||||||
|
self.assertClose(out1.dists, out2.dists)
|
||||||
|
# Check idx
|
||||||
|
self.assertTrue(torch.all(out1.idx == out2.idx))
|
||||||
|
|
||||||
|
# backward
|
||||||
|
grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
|
||||||
|
loss1 = (out1.dists * grad_dist).sum()
|
||||||
|
loss1.backward()
|
||||||
|
loss2 = (out2.dists * grad_dist).sum()
|
||||||
|
loss2.backward()
|
||||||
|
|
||||||
|
self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
|
||||||
|
self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
|
||||||
|
|
||||||
|
def test_ball_query_vs_python_square_cpu(self):
|
||||||
|
device = torch.device("cpu")
|
||||||
|
self._ball_query_vs_python_square_helper(device)
|
||||||
|
|
||||||
|
def test_ball_query_vs_python_square_cuda(self):
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
self._ball_query_vs_python_square_helper(device)
|
||||||
|
|
||||||
|
def _ball_query_vs_python_ragged_helper(self, device):
|
||||||
|
Ns = [1, 4]
|
||||||
|
Ds = [3, 5, 8]
|
||||||
|
P1s = [8, 24]
|
||||||
|
P2s = [8, 16, 32]
|
||||||
|
Ks = [2, 3, 10]
|
||||||
|
Rs = [1.4, 5] # radius
|
||||||
|
factors = [Ns, Ds, P1s, P2s, Ks, Rs]
|
||||||
|
for N, D, P1, P2, K, R in product(*factors):
|
||||||
|
x = torch.rand((N, P1, D), device=device, requires_grad=True)
|
||||||
|
y = torch.rand((N, P2, D), device=device, requires_grad=True)
|
||||||
|
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
|
||||||
|
lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
|
||||||
|
|
||||||
|
x_csrc = x.clone().detach()
|
||||||
|
x_csrc.requires_grad_(True)
|
||||||
|
y_csrc = y.clone().detach()
|
||||||
|
y_csrc.requires_grad_(True)
|
||||||
|
|
||||||
|
# forward
|
||||||
|
out1 = self._ball_query_naive(
|
||||||
|
x, y, lengths1=lengths1, lengths2=lengths2, K=K, radius=R
|
||||||
|
)
|
||||||
|
out2 = ball_query(
|
||||||
|
x_csrc,
|
||||||
|
y_csrc,
|
||||||
|
lengths1=lengths1,
|
||||||
|
lengths2=lengths2,
|
||||||
|
K=K,
|
||||||
|
radius=R,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertClose(out1.idx, out2.idx)
|
||||||
|
self.assertClose(out1.dists, out2.dists)
|
||||||
|
|
||||||
|
# backward
|
||||||
|
grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
|
||||||
|
loss1 = (out1.dists * grad_dist).sum()
|
||||||
|
loss1.backward()
|
||||||
|
loss2 = (out2.dists * grad_dist).sum()
|
||||||
|
loss2.backward()
|
||||||
|
|
||||||
|
self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
|
||||||
|
self.assertClose(y_csrc.grad, y.grad, atol=5e-6)
|
||||||
|
|
||||||
|
def test_ball_query_vs_python_ragged_cpu(self):
|
||||||
|
device = torch.device("cpu")
|
||||||
|
self._ball_query_vs_python_ragged_helper(device)
|
||||||
|
|
||||||
|
def test_ball_query_vs_python_ragged_cuda(self):
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
self._ball_query_vs_python_ragged_helper(device)
|
||||||
|
|
||||||
|
def test_ball_query_output_simple(self):
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
N, P1, P2, K = 5, 8, 16, 4
|
||||||
|
sphere = ico_sphere(level=2, device=device).extend(N)
|
||||||
|
points_1 = sample_points_from_meshes(sphere, P1)
|
||||||
|
points_2 = sample_points_from_meshes(sphere, P2) * 5.0
|
||||||
|
radius = 6.0
|
||||||
|
|
||||||
|
naive_out = self._ball_query_naive(
|
||||||
|
points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
|
||||||
|
)
|
||||||
|
cuda_out = ball_query(points_1, points_2, K=K, radius=radius)
|
||||||
|
|
||||||
|
# All points should have N sample neighbors as radius is large
|
||||||
|
# Zero is a valid index but can only be present once (i.e. no zero padding)
|
||||||
|
naive_out_zeros = (naive_out.idx == 0).sum(dim=-1).max()
|
||||||
|
cuda_out_zeros = (cuda_out.idx == 0).sum(dim=-1).max()
|
||||||
|
self.assertTrue(naive_out_zeros == 0 or naive_out_zeros == 1)
|
||||||
|
self.assertTrue(cuda_out_zeros == 0 or cuda_out_zeros == 1)
|
||||||
|
|
||||||
|
# All points should now have zero sample neighbors as radius is small
|
||||||
|
radius = 0.5
|
||||||
|
naive_out = self._ball_query_naive(
|
||||||
|
points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
|
||||||
|
)
|
||||||
|
cuda_out = ball_query(points_1, points_2, K=K, radius=radius)
|
||||||
|
naive_out_allzeros = (naive_out.idx == -1).all()
|
||||||
|
cuda_out_allzeros = (cuda_out.idx == -1).sum()
|
||||||
|
self.assertTrue(naive_out_allzeros)
|
||||||
|
self.assertTrue(cuda_out_allzeros)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ball_query_square(
|
||||||
|
N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
|
||||||
|
):
|
||||||
|
device = torch.device(device)
|
||||||
|
pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
|
||||||
|
pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
|
||||||
|
grad_dists = torch.randn(N, P1, K, device=device)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def output():
|
||||||
|
out = ball_query(pts1, pts2, K=K, radius=radius)
|
||||||
|
loss = (out.dists * grad_dists).sum()
|
||||||
|
loss.backward()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ball_query_ragged(
|
||||||
|
N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
|
||||||
|
):
|
||||||
|
device = torch.device(device)
|
||||||
|
pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
|
||||||
|
pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
|
||||||
|
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
|
||||||
|
lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
|
||||||
|
grad_dists = torch.randn(N, P1, K, device=device)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def output():
|
||||||
|
out = ball_query(
|
||||||
|
pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K, radius=radius
|
||||||
|
)
|
||||||
|
loss = (out.dists * grad_dists).sum()
|
||||||
|
loss.backward()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return output
|
Loading…
x
Reference in New Issue
Block a user