mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-03 02:35:58 +08:00
Ball Query
Summary: Implementation of ball query from PointNet++. This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences: - It will return the **first** K neighbors within a specified radius as opposed to the **closest** K neighbors. - As all the points in p2 do not need to be considered to find the closest K, the algorithm is much faster than KNN when p2 has a large number of points. - The neighbors are not sorted - Due to the radius threshold it is not guaranteed that there will be K neighbors even if there are more than K points in p2. - The padding value for `idx` is -1 instead of 0. # Note: - Some of the code is very similar to KNN so it could be possible to modify the KNN forward kernels to support ball query. - Some users might want to use kNN with ball query - for this we could provide a wrapper function around the current `knn_points` which enables applying the radius threshold afterwards as an alternative. This could be called `ball_query_knn`. Reviewed By: jcjohnson Differential Revision: D30261362 fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e5c58a8a8b
commit
103da63393
91
pytorch3d/csrc/ball_query/ball_query.h
Normal file
91
pytorch3d/csrc/ball_query/ball_query.h
Normal file
@@ -0,0 +1,91 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <tuple>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// Compute indices of K neighbors in pointcloud p2 to points
|
||||
// in pointcloud p1 which fall within a specified radius
|
||||
//
|
||||
// Args:
|
||||
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
||||
// containing P1 points of dimension D.
|
||||
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||
// containing P2 points of dimension D.
|
||||
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||
// K: Integer giving the upper bound on the number of samples to take
|
||||
// within the radius
|
||||
// radius: the radius around each point within which the neighbors need to be
|
||||
// located
|
||||
//
|
||||
// Returns:
|
||||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||
// p1_neighbor_idx[n, i, k] = j means that the kth
|
||||
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||
// This is padded with -1s both where a cloud in p2 has fewer than
|
||||
// S points and where a cloud in p1 has fewer than P1 points and
|
||||
// also if there are fewer than K points which satisfy the radius
|
||||
// threshold.
|
||||
//
|
||||
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
|
||||
// distance from each point p1[n, p, :] to its K neighbors
|
||||
// p2[n, p1_neighbor_idx[n, p, k], :].
|
||||
|
||||
// CPU implementation
|
||||
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const int K,
|
||||
const float radius);
|
||||
|
||||
// CUDA implementation
|
||||
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const int K,
|
||||
const float radius);
|
||||
|
||||
// Implementation which is exposed
|
||||
// Note: the backward pass reuses the KNearestNeighborBackward kernel
|
||||
inline std::tuple<at::Tensor, at::Tensor> BallQuery(
|
||||
const at::Tensor& p1,
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
float radius) {
|
||||
if (p1.is_cuda() || p2.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(p1);
|
||||
CHECK_CUDA(p2);
|
||||
return BallQueryCuda(
|
||||
p1.contiguous(),
|
||||
p2.contiguous(),
|
||||
lengths1.contiguous(),
|
||||
lengths2.contiguous(),
|
||||
K,
|
||||
radius);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return BallQueryCpu(
|
||||
p1.contiguous(),
|
||||
p2.contiguous(),
|
||||
lengths1.contiguous(),
|
||||
lengths2.contiguous(),
|
||||
K,
|
||||
radius);
|
||||
}
|
||||
Reference in New Issue
Block a user