mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Farthest point sampling python naive
Summary: This is a naive python implementation of the iterative farthest point sampling algorithm along with associated simple tests. The C++/CUDA implementations will follow in subsequent diffs. The algorithm is used to subsample a pointcloud with better coverage of the space of the pointcloud. The function has not been added to `__init__.py`. I will add this after the full C++/CUDA implementations. Reviewed By: jcjohnson Differential Revision: D30285716 fbshipit-source-id: 33f4181041fc652776406bcfd67800a6f0c3dd58
This commit is contained in:
parent
a0d76a7080
commit
3b7d78c7a7
@ -12,6 +12,7 @@ from torch.autograd import Function
|
|||||||
from torch.autograd.function import once_differentiable
|
from torch.autograd.function import once_differentiable
|
||||||
|
|
||||||
from .knn import _KNN
|
from .knn import _KNN
|
||||||
|
from .utils import masked_gather
|
||||||
|
|
||||||
|
|
||||||
class _ball_query(Function):
|
class _ball_query(Function):
|
||||||
@ -123,7 +124,6 @@ def ball_query(
|
|||||||
p2 = p2.contiguous()
|
p2 = p2.contiguous()
|
||||||
P1 = p1.shape[1]
|
P1 = p1.shape[1]
|
||||||
P2 = p2.shape[1]
|
P2 = p2.shape[1]
|
||||||
D = p2.shape[2]
|
|
||||||
N = p1.shape[0]
|
N = p1.shape[0]
|
||||||
|
|
||||||
if lengths1 is None:
|
if lengths1 is None:
|
||||||
@ -135,16 +135,6 @@ def ball_query(
|
|||||||
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
|
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
|
||||||
|
|
||||||
# Gather the neighbors if needed
|
# Gather the neighbors if needed
|
||||||
points_nn = None
|
points_nn = masked_gather(p2, idx) if return_nn else None
|
||||||
if return_nn:
|
|
||||||
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D)
|
|
||||||
idx_mask = idx_expanded.eq(-1)
|
|
||||||
idx_new = idx_expanded.clone()
|
|
||||||
# Replace -1 values with 0 for gather
|
|
||||||
idx_new[idx_mask] = 0
|
|
||||||
# Gather points from p2
|
|
||||||
points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new)
|
|
||||||
# Replace padded values
|
|
||||||
points_nn[idx_mask] = 0.0
|
|
||||||
|
|
||||||
return _KNN(dists=dists, idx=idx, knn=points_nn)
|
return _KNN(dists=dists, idx=idx, knn=points_nn)
|
||||||
|
124
pytorch3d/ops/sample_farthest_points.py
Normal file
124
pytorch3d/ops/sample_farthest_points.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from random import randint
|
||||||
|
from typing import Optional, Tuple, Union, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .utils import masked_gather
|
||||||
|
|
||||||
|
|
||||||
|
def sample_farthest_points_naive(
|
||||||
|
points: torch.Tensor,
|
||||||
|
lengths: Optional[torch.Tensor] = None,
|
||||||
|
K: Union[int, List, torch.Tensor] = 50,
|
||||||
|
random_start_point: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Iterative farthest point sampling algorithm [1] to subsample a set of
|
||||||
|
K points from a given pointcloud. At each iteration, a point is selected
|
||||||
|
which has the largest nearest neighbor distance to any of the
|
||||||
|
already selected points.
|
||||||
|
|
||||||
|
Farthest point sampling provides more uniform coverage of the input
|
||||||
|
point cloud compared to uniform random sampling.
|
||||||
|
|
||||||
|
[1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
|
||||||
|
on Point Sets in a Metric Space", NeurIPS 2017.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points: (N, P, D) array containing the batch of pointclouds
|
||||||
|
lengths: (N,) number of points in each pointcloud (to support heterogeneous
|
||||||
|
batches of pointclouds)
|
||||||
|
K: samples you want in each sampled point cloud (this is typically << P). If
|
||||||
|
K is an int then the same number of samples are selected for each
|
||||||
|
pointcloud in the batch. If K is a tensor is should be length (N,)
|
||||||
|
giving the number of samples to select for each element in the batch
|
||||||
|
random_start_point: bool, if True, a random point is selected as the starting
|
||||||
|
point for iterative sampling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
selected_points: (N, K, D), array of selected values from points. If the input
|
||||||
|
K is a tensor, then the shape will be (N, max(K), D), and padded with
|
||||||
|
0.0 for batch elements where k_i < max(K).
|
||||||
|
selected_indices: (N, K) array of selected indices. If the input
|
||||||
|
K is a tensor, then the shape will be (N, max(K), D), and padded with
|
||||||
|
-1 for batch elements where k_i < max(K).
|
||||||
|
"""
|
||||||
|
N, P, D = points.shape
|
||||||
|
device = points.device
|
||||||
|
|
||||||
|
# Validate inputs
|
||||||
|
if lengths is None:
|
||||||
|
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
|
if lengths.shape[0] != N:
|
||||||
|
raise ValueError("points and lengths must have same batch dimension.")
|
||||||
|
|
||||||
|
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
||||||
|
if isinstance(K, int):
|
||||||
|
K = torch.full((N,), K, dtype=torch.int64, device=device)
|
||||||
|
elif isinstance(K, list):
|
||||||
|
K = torch.tensor(K, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
|
if K.shape[0] != N:
|
||||||
|
raise ValueError("K and points must have the same batch dimension")
|
||||||
|
|
||||||
|
# Find max value of K
|
||||||
|
max_K = torch.max(K)
|
||||||
|
|
||||||
|
# List of selected indices from each batch element
|
||||||
|
all_sampled_indices = []
|
||||||
|
|
||||||
|
for n in range(N):
|
||||||
|
# Initialize an array for the sampled indices, shape: (max_K,)
|
||||||
|
sample_idx_batch = torch.full(
|
||||||
|
(max_K,), fill_value=-1, dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize closest distances to inf, shape: (P,)
|
||||||
|
# This will be updated at each iteration to track the closest distance of the
|
||||||
|
# remaining points to any of the selected points
|
||||||
|
# pyre-fixme[16]: `torch.Tensor` has no attribute new_full.
|
||||||
|
closest_dists = points.new_full(
|
||||||
|
(lengths[n],), float("inf"), dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select a random point index and save it as the starting point
|
||||||
|
selected_idx = randint(0, lengths[n] - 1) if random_start_point else 0
|
||||||
|
sample_idx_batch[0] = selected_idx
|
||||||
|
|
||||||
|
# If the pointcloud has fewer than K points then only iterate over the min
|
||||||
|
k_n = min(lengths[n], K[n])
|
||||||
|
|
||||||
|
# Iteratively select points for a maximum of k_n
|
||||||
|
for i in range(1, k_n):
|
||||||
|
# Find the distance between the last selected point
|
||||||
|
# and all the other points. If a point has already been selected
|
||||||
|
# it's distance will be 0.0 so it will not be selected again as the max.
|
||||||
|
dist = points[n, selected_idx, :] - points[n, : lengths[n], :]
|
||||||
|
dist_to_last_selected = (dist ** 2).sum(-1) # (P - i)
|
||||||
|
|
||||||
|
# If closer than currently saved distance to one of the selected
|
||||||
|
# points, then updated closest_dists
|
||||||
|
closest_dists = torch.min(dist_to_last_selected, closest_dists) # (P - i)
|
||||||
|
|
||||||
|
# The aim is to pick the point that has the largest
|
||||||
|
# nearest neighbour distance to any of the already selected points
|
||||||
|
selected_idx = torch.argmax(closest_dists)
|
||||||
|
sample_idx_batch[i] = selected_idx
|
||||||
|
|
||||||
|
# Add the list of points for this batch to the final list
|
||||||
|
all_sampled_indices.append(sample_idx_batch)
|
||||||
|
|
||||||
|
all_sampled_indices = torch.stack(all_sampled_indices, dim=0)
|
||||||
|
|
||||||
|
# Gather the points
|
||||||
|
all_sampled_points = masked_gather(points, all_sampled_indices)
|
||||||
|
|
||||||
|
# Return (N, max_K, D) subsampled points and indices
|
||||||
|
return all_sampled_points, all_sampled_indices
|
@ -15,6 +15,54 @@ if TYPE_CHECKING:
|
|||||||
from pytorch3d.structures import Pointclouds
|
from pytorch3d.structures import Pointclouds
|
||||||
|
|
||||||
|
|
||||||
|
def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Helper function for torch.gather to collect the points at
|
||||||
|
the given indices in idx where some of the indices might be -1 to
|
||||||
|
indicate padding. These indices are first replaced with 0.
|
||||||
|
Then the points are gathered after which the padded values
|
||||||
|
are set to 0.0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points: (N, P, D) float32 tensor of points
|
||||||
|
idx: (N, K) or (N, P, K) long tensor of indices into points, where
|
||||||
|
some indices are -1 to indicate padding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
selected_points: (N, K, D) float32 tensor of points
|
||||||
|
at the given indices
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(idx) != len(points):
|
||||||
|
raise ValueError("points and idx must have the same batch dimension")
|
||||||
|
|
||||||
|
N, P, D = points.shape
|
||||||
|
|
||||||
|
if idx.ndim == 3:
|
||||||
|
# Case: KNN, Ball Query where idx is of shape (N, P', K)
|
||||||
|
# where P' is not necessarily the same as P as the
|
||||||
|
# points may be gathered from a different pointcloud.
|
||||||
|
K = idx.shape[2]
|
||||||
|
# Match dimensions for points and indices
|
||||||
|
idx_expanded = idx[..., None].expand(-1, -1, -1, D)
|
||||||
|
points = points[:, :, None, :].expand(-1, -1, K, -1)
|
||||||
|
elif idx.ndim == 2:
|
||||||
|
# Farthest point sampling where idx is of shape (N, K)
|
||||||
|
idx_expanded = idx[..., None].expand(-1, -1, D)
|
||||||
|
else:
|
||||||
|
raise ValueError("idx format is not supported %s" % repr(idx.shape))
|
||||||
|
|
||||||
|
idx_expanded_mask = idx_expanded.eq(-1)
|
||||||
|
idx_expanded = idx_expanded.clone()
|
||||||
|
# Replace -1 values with 0 for gather
|
||||||
|
idx_expanded[idx_expanded_mask] = 0
|
||||||
|
# Gather points
|
||||||
|
selected_points = points.gather(dim=1, index=idx_expanded)
|
||||||
|
# Replace padded values
|
||||||
|
selected_points[idx_expanded_mask] = 0.0
|
||||||
|
return selected_points
|
||||||
|
|
||||||
|
|
||||||
def wmean(
|
def wmean(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
weight: Optional[torch.Tensor] = None,
|
weight: Optional[torch.Tensor] = None,
|
||||||
|
@ -76,3 +76,13 @@ class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
|||||||
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
|
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
|
||||||
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
|
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
|
||||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||||
|
|
||||||
|
def test_masked_gather_errors(self):
|
||||||
|
idx = torch.randint(0, 10, size=(5, 10, 4, 2))
|
||||||
|
points = torch.randn(size=(5, 10, 3))
|
||||||
|
with self.assertRaisesRegex(ValueError, "format is not supported"):
|
||||||
|
oputil.masked_gather(points, idx)
|
||||||
|
|
||||||
|
points = torch.randn(size=(2, 10, 3))
|
||||||
|
with self.assertRaisesRegex(ValueError, "same batch dimension"):
|
||||||
|
oputil.masked_gather(points, idx)
|
||||||
|
111
tests/test_sample_farthest_points.py
Normal file
111
tests/test_sample_farthest_points.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||||
|
from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive
|
||||||
|
from pytorch3d.ops.utils import masked_gather
|
||||||
|
|
||||||
|
|
||||||
|
class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||||
|
def test_simple(self):
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
# fmt: off
|
||||||
|
points = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[-1.0, -1.0], # noqa: E241, E201
|
||||||
|
[-1.3, 1.1], # noqa: E241, E201
|
||||||
|
[ 0.2, -1.1], # noqa: E241, E201
|
||||||
|
[ 0.0, 0.0], # noqa: E241, E201
|
||||||
|
[ 1.3, 1.3], # noqa: E241, E201
|
||||||
|
[ 1.0, 0.5], # noqa: E241, E201
|
||||||
|
[-1.3, 0.2], # noqa: E241, E201
|
||||||
|
[ 1.5, -0.5], # noqa: E241, E201
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-2.2, -2.4], # noqa: E241, E201
|
||||||
|
[-2.1, 2.0], # noqa: E241, E201
|
||||||
|
[ 2.2, 2.1], # noqa: E241, E201
|
||||||
|
[ 2.1, -2.4], # noqa: E241, E201
|
||||||
|
[ 0.4, -1.0], # noqa: E241, E201
|
||||||
|
[ 0.3, 0.3], # noqa: E241, E201
|
||||||
|
[ 1.2, 0.5], # noqa: E241, E201
|
||||||
|
[ 4.5, 4.5], # noqa: E241, E201
|
||||||
|
],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
|
||||||
|
out_points, out_inds = sample_farthest_points_naive(points, K=2)
|
||||||
|
self.assertClose(out_inds, expected_inds)
|
||||||
|
|
||||||
|
# Gather the points
|
||||||
|
expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1])
|
||||||
|
self.assertClose(out_points, points.gather(dim=1, index=expected_inds))
|
||||||
|
|
||||||
|
# Different number of points sampled for each pointcloud in the batch
|
||||||
|
expected_inds = torch.tensor(
|
||||||
|
[[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2])
|
||||||
|
self.assertClose(out_inds, expected_inds)
|
||||||
|
|
||||||
|
# Gather the points
|
||||||
|
expected_points = masked_gather(points, expected_inds)
|
||||||
|
self.assertClose(out_points, expected_points)
|
||||||
|
|
||||||
|
def test_random_heterogeneous(self):
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
N, P, D, K = 5, 40, 5, 8
|
||||||
|
points = torch.randn((N, P, D), device=device)
|
||||||
|
out_points, out_idxs = sample_farthest_points_naive(points, K=K)
|
||||||
|
self.assertTrue(out_idxs.min() >= 0)
|
||||||
|
for n in range(N):
|
||||||
|
self.assertEqual(out_idxs[n].ne(-1).sum(), K)
|
||||||
|
|
||||||
|
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
|
||||||
|
out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50)
|
||||||
|
|
||||||
|
for n in range(N):
|
||||||
|
# Check that for heterogeneous batches, the max number of
|
||||||
|
# selected points is less than the length
|
||||||
|
self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n])
|
||||||
|
self.assertTrue(out_idxs[n].max() <= lengths[n])
|
||||||
|
|
||||||
|
# Check there are no duplicate indices
|
||||||
|
val_mask = out_idxs[n].ne(-1)
|
||||||
|
vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
|
||||||
|
self.assertTrue(counts.le(1).all())
|
||||||
|
|
||||||
|
def test_errors(self):
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
N, P, D, K = 5, 40, 5, 8
|
||||||
|
points = torch.randn((N, P, D), device=device)
|
||||||
|
wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)
|
||||||
|
|
||||||
|
# K has diferent batch dimension to points
|
||||||
|
with self.assertRaisesRegex(ValueError, "K and points must have"):
|
||||||
|
sample_farthest_points_naive(points, K=wrong_batch_dim)
|
||||||
|
|
||||||
|
# lengths has diferent batch dimension to points
|
||||||
|
with self.assertRaisesRegex(ValueError, "points and lengths must have"):
|
||||||
|
sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K)
|
||||||
|
|
||||||
|
def test_random_start(self):
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
N, P, D, K = 5, 40, 5, 8
|
||||||
|
points = torch.randn((N, P, D), device=device)
|
||||||
|
out_points, out_idxs = sample_farthest_points_naive(
|
||||||
|
points, K=K, random_start_point=True
|
||||||
|
)
|
||||||
|
# Check the first index is not 0 for all batch elements
|
||||||
|
# when random_start_point = True
|
||||||
|
self.assertTrue(out_idxs[:, 0].sum() > 0)
|
Loading…
x
Reference in New Issue
Block a user