mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
Ball Query
Summary: Implementation of ball query from PointNet++. This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences: - It will return the **first** K neighbors within a specified radius as opposed to the **closest** K neighbors. - As all the points in p2 do not need to be considered to find the closest K, the algorithm is much faster than KNN when p2 has a large number of points. - The neighbors are not sorted - Due to the radius threshold it is not guaranteed that there will be K neighbors even if there are more than K points in p2. - The padding value for `idx` is -1 instead of 0. # Note: - Some of the code is very similar to KNN so it could be possible to modify the KNN forward kernels to support ball query. - Some users might want to use kNN with ball query - for this we could provide a wrapper function around the current `knn_points` which enables applying the radius threshold afterwards as an alternative. This could be called `ball_query_knn`. Reviewed By: jcjohnson Differential Revision: D30261362 fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e5c58a8a8b
commit
103da63393
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .ball_query import ball_query
|
||||
from .cameras_alignment import corresponding_cameras_alignment
|
||||
from .cubify import cubify
|
||||
from .graph_conv import GraphConv
|
||||
@@ -34,5 +35,4 @@ from .utils import (
|
||||
)
|
||||
from .vert_align import vert_align
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
||||
150
pytorch3d/ops/ball_query.py
Normal file
150
pytorch3d/ops/ball_query.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from .knn import _KNN
|
||||
|
||||
|
||||
class _ball_query(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for Ball Query C++/CUDA implementations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
|
||||
"""
|
||||
Arguments defintions the same as in the ball_query function
|
||||
"""
|
||||
idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius)
|
||||
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
|
||||
ctx.mark_non_differentiable(idx)
|
||||
return dists, idx
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_dists, grad_idx):
|
||||
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
||||
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||
if not (grad_dists.dtype == torch.float32):
|
||||
grad_dists = grad_dists.float()
|
||||
if not (p1.dtype == torch.float32):
|
||||
p1 = p1.float()
|
||||
if not (p2.dtype == torch.float32):
|
||||
p2 = p2.float()
|
||||
|
||||
# Reuse the KNN backward function
|
||||
grad_p1, grad_p2 = _C.knn_points_backward(
|
||||
p1, p2, lengths1, lengths2, idx, grad_dists
|
||||
)
|
||||
return grad_p1, grad_p2, None, None, None, None
|
||||
|
||||
|
||||
def ball_query(
|
||||
p1: torch.Tensor,
|
||||
p2: torch.Tensor,
|
||||
lengths1: Union[torch.Tensor, None] = None,
|
||||
lengths2: Union[torch.Tensor, None] = None,
|
||||
K: int = 500,
|
||||
radius: float = 0.2,
|
||||
return_nn: bool = True,
|
||||
):
|
||||
"""
|
||||
Ball Query is an alternative to KNN. It can be
|
||||
used to find all points in p2 that are within a specified radius
|
||||
to the query point in p1 (with an upper limit of K neighbors).
|
||||
|
||||
The neighbors returned are not necssarily the *nearest* to the
|
||||
point in p1, just the first K values in p2 which are within the
|
||||
specified radius.
|
||||
|
||||
This method is faster than kNN when there are large numbers of points
|
||||
in p2 and the ordering of neighbors is not important compared to the
|
||||
distance being within the radius threshold.
|
||||
|
||||
"Ball query’s local neighborhood guarantees a fixed region scale thus
|
||||
making local region features more generalizable across space, which is
|
||||
preferred for tasks requiring local pattern recognition
|
||||
(e.g. semantic point labeling)" [1].
|
||||
|
||||
[1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
|
||||
on Point Sets in a Metric Space", NeurIPS 2017.
|
||||
|
||||
Args:
|
||||
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
|
||||
containing up to P1 points of dimension D. These represent the centers of
|
||||
the ball queries.
|
||||
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
|
||||
containing up to P2 points of dimension D.
|
||||
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
|
||||
length of each pointcloud in p1. Or None to indicate that every cloud has
|
||||
length P1.
|
||||
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||
length P2.
|
||||
K: Integer giving the upper bound on the number of samples to take
|
||||
within the radius
|
||||
radius: the radius around each point within which the neighbors need to be located
|
||||
return_nn: If set to True returns the K neighbor points in p2 for each point in p1.
|
||||
|
||||
Returns:
|
||||
dists: Tensor of shape (N, P1, K) giving the squared distances to
|
||||
the neighbors. This is padded with zeros both where a cloud in p2
|
||||
has fewer than S points and where a cloud in p1 has fewer than P1 points
|
||||
and also if there are fewer than K points which satisfy the radius threshold.
|
||||
|
||||
idx: LongTensor of shape (N, P1, K) giving the indices of the
|
||||
S neighbors in p2 for points in p1.
|
||||
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th
|
||||
neighbor to `p1[n, i]` in `p2[n]`. This is padded with -1 both where a cloud
|
||||
in p2 has fewer than S points and where a cloud in p1 has fewer than P1
|
||||
points and also if there are fewer than K points which satisfy the radius threshold.
|
||||
|
||||
nn: Tensor of shape (N, P1, K, D) giving the K neighbors in p2 for
|
||||
each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th neighbor
|
||||
for `p1[n, i]`. Returned if `return_nn` is True. The output is a tensor
|
||||
of shape (N, P1, K, U).
|
||||
|
||||
"""
|
||||
if p1.shape[0] != p2.shape[0]:
|
||||
raise ValueError("pts1 and pts2 must have the same batch dimension.")
|
||||
if p1.shape[2] != p2.shape[2]:
|
||||
raise ValueError("pts1 and pts2 must have the same point dimension.")
|
||||
|
||||
p1 = p1.contiguous()
|
||||
p2 = p2.contiguous()
|
||||
P1 = p1.shape[1]
|
||||
P2 = p2.shape[1]
|
||||
D = p2.shape[2]
|
||||
N = p1.shape[0]
|
||||
|
||||
if lengths1 is None:
|
||||
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
|
||||
if lengths2 is None:
|
||||
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
|
||||
|
||||
# pyre-fixme[16]: `_ball_query` has no attribute `apply`.
|
||||
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
|
||||
|
||||
# Gather the neighbors if needed
|
||||
points_nn = None
|
||||
if return_nn:
|
||||
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D)
|
||||
idx_mask = idx_expanded.eq(-1)
|
||||
idx_new = idx_expanded.clone()
|
||||
# Replace -1 values with 0 for gather
|
||||
idx_new[idx_mask] = 0
|
||||
# Gather points from p2
|
||||
points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new)
|
||||
# Replace padded values
|
||||
points_nn[idx_mask] = 0.0
|
||||
|
||||
return _KNN(dists=dists, idx=idx, knn=points_nn)
|
||||
Reference in New Issue
Block a user