diff --git a/pytorch3d/ops/ball_query.py b/pytorch3d/ops/ball_query.py index 5105400a..e352d878 100644 --- a/pytorch3d/ops/ball_query.py +++ b/pytorch3d/ops/ball_query.py @@ -12,6 +12,7 @@ from torch.autograd import Function from torch.autograd.function import once_differentiable from .knn import _KNN +from .utils import masked_gather class _ball_query(Function): @@ -123,7 +124,6 @@ def ball_query( p2 = p2.contiguous() P1 = p1.shape[1] P2 = p2.shape[1] - D = p2.shape[2] N = p1.shape[0] if lengths1 is None: @@ -135,16 +135,6 @@ def ball_query( dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius) # Gather the neighbors if needed - points_nn = None - if return_nn: - idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D) - idx_mask = idx_expanded.eq(-1) - idx_new = idx_expanded.clone() - # Replace -1 values with 0 for gather - idx_new[idx_mask] = 0 - # Gather points from p2 - points_nn = p2[:, :, None].expand(-1, -1, K, -1).gather(1, idx_new) - # Replace padded values - points_nn[idx_mask] = 0.0 + points_nn = masked_gather(p2, idx) if return_nn else None return _KNN(dists=dists, idx=idx, knn=points_nn) diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py new file mode 100644 index 00000000..071f3ffe --- /dev/null +++ b/pytorch3d/ops/sample_farthest_points.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from random import randint +from typing import Optional, Tuple, Union, List + +import torch + +from .utils import masked_gather + + +def sample_farthest_points_naive( + points: torch.Tensor, + lengths: Optional[torch.Tensor] = None, + K: Union[int, List, torch.Tensor] = 50, + random_start_point: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Iterative farthest point sampling algorithm [1] to subsample a set of + K points from a given pointcloud. At each iteration, a point is selected + which has the largest nearest neighbor distance to any of the + already selected points. + + Farthest point sampling provides more uniform coverage of the input + point cloud compared to uniform random sampling. + + [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning + on Point Sets in a Metric Space", NeurIPS 2017. + + Args: + points: (N, P, D) array containing the batch of pointclouds + lengths: (N,) number of points in each pointcloud (to support heterogeneous + batches of pointclouds) + K: samples you want in each sampled point cloud (this is typically << P). If + K is an int then the same number of samples are selected for each + pointcloud in the batch. If K is a tensor is should be length (N,) + giving the number of samples to select for each element in the batch + random_start_point: bool, if True, a random point is selected as the starting + point for iterative sampling. + + Returns: + selected_points: (N, K, D), array of selected values from points. If the input + K is a tensor, then the shape will be (N, max(K), D), and padded with + 0.0 for batch elements where k_i < max(K). + selected_indices: (N, K) array of selected indices. If the input + K is a tensor, then the shape will be (N, max(K), D), and padded with + -1 for batch elements where k_i < max(K). + """ + N, P, D = points.shape + device = points.device + + # Validate inputs + if lengths is None: + lengths = torch.full((N,), P, dtype=torch.int64, device=device) + + if lengths.shape[0] != N: + raise ValueError("points and lengths must have same batch dimension.") + + # TODO: support providing K as a ratio of the total number of points instead of as an int + if isinstance(K, int): + K = torch.full((N,), K, dtype=torch.int64, device=device) + elif isinstance(K, list): + K = torch.tensor(K, dtype=torch.int64, device=device) + + if K.shape[0] != N: + raise ValueError("K and points must have the same batch dimension") + + # Find max value of K + max_K = torch.max(K) + + # List of selected indices from each batch element + all_sampled_indices = [] + + for n in range(N): + # Initialize an array for the sampled indices, shape: (max_K,) + sample_idx_batch = torch.full( + (max_K,), fill_value=-1, dtype=torch.int64, device=device + ) + + # Initialize closest distances to inf, shape: (P,) + # This will be updated at each iteration to track the closest distance of the + # remaining points to any of the selected points + # pyre-fixme[16]: `torch.Tensor` has no attribute new_full. + closest_dists = points.new_full( + (lengths[n],), float("inf"), dtype=torch.float32 + ) + + # Select a random point index and save it as the starting point + selected_idx = randint(0, lengths[n] - 1) if random_start_point else 0 + sample_idx_batch[0] = selected_idx + + # If the pointcloud has fewer than K points then only iterate over the min + k_n = min(lengths[n], K[n]) + + # Iteratively select points for a maximum of k_n + for i in range(1, k_n): + # Find the distance between the last selected point + # and all the other points. If a point has already been selected + # it's distance will be 0.0 so it will not be selected again as the max. + dist = points[n, selected_idx, :] - points[n, : lengths[n], :] + dist_to_last_selected = (dist ** 2).sum(-1) # (P - i) + + # If closer than currently saved distance to one of the selected + # points, then updated closest_dists + closest_dists = torch.min(dist_to_last_selected, closest_dists) # (P - i) + + # The aim is to pick the point that has the largest + # nearest neighbour distance to any of the already selected points + selected_idx = torch.argmax(closest_dists) + sample_idx_batch[i] = selected_idx + + # Add the list of points for this batch to the final list + all_sampled_indices.append(sample_idx_batch) + + all_sampled_indices = torch.stack(all_sampled_indices, dim=0) + + # Gather the points + all_sampled_points = masked_gather(points, all_sampled_indices) + + # Return (N, max_K, D) subsampled points and indices + return all_sampled_points, all_sampled_indices diff --git a/pytorch3d/ops/utils.py b/pytorch3d/ops/utils.py index c6365449..d61ea5bf 100644 --- a/pytorch3d/ops/utils.py +++ b/pytorch3d/ops/utils.py @@ -15,6 +15,54 @@ if TYPE_CHECKING: from pytorch3d.structures import Pointclouds +def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + Helper function for torch.gather to collect the points at + the given indices in idx where some of the indices might be -1 to + indicate padding. These indices are first replaced with 0. + Then the points are gathered after which the padded values + are set to 0.0. + + Args: + points: (N, P, D) float32 tensor of points + idx: (N, K) or (N, P, K) long tensor of indices into points, where + some indices are -1 to indicate padding + + Returns: + selected_points: (N, K, D) float32 tensor of points + at the given indices + """ + + if len(idx) != len(points): + raise ValueError("points and idx must have the same batch dimension") + + N, P, D = points.shape + + if idx.ndim == 3: + # Case: KNN, Ball Query where idx is of shape (N, P', K) + # where P' is not necessarily the same as P as the + # points may be gathered from a different pointcloud. + K = idx.shape[2] + # Match dimensions for points and indices + idx_expanded = idx[..., None].expand(-1, -1, -1, D) + points = points[:, :, None, :].expand(-1, -1, K, -1) + elif idx.ndim == 2: + # Farthest point sampling where idx is of shape (N, K) + idx_expanded = idx[..., None].expand(-1, -1, D) + else: + raise ValueError("idx format is not supported %s" % repr(idx.shape)) + + idx_expanded_mask = idx_expanded.eq(-1) + idx_expanded = idx_expanded.clone() + # Replace -1 values with 0 for gather + idx_expanded[idx_expanded_mask] = 0 + # Gather points + selected_points = points.gather(dim=1, index=idx_expanded) + # Replace padded values + selected_points[idx_expanded_mask] = 0.0 + return selected_points + + def wmean( x: torch.Tensor, weight: Optional[torch.Tensor] = None, diff --git a/tests/test_ops_utils.py b/tests/test_ops_utils.py index d12a93b3..a4653373 100644 --- a/tests/test_ops_utils.py +++ b/tests/test_ops_utils.py @@ -76,3 +76,13 @@ class TestOpsUtils(TestCaseMixin, unittest.TestCase): mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False) mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np) self.assertClose(mean.cpu().data.numpy(), mean_gt) + + def test_masked_gather_errors(self): + idx = torch.randint(0, 10, size=(5, 10, 4, 2)) + points = torch.randn(size=(5, 10, 3)) + with self.assertRaisesRegex(ValueError, "format is not supported"): + oputil.masked_gather(points, idx) + + points = torch.randn(size=(2, 10, 3)) + with self.assertRaisesRegex(ValueError, "same batch dimension"): + oputil.masked_gather(points, idx) diff --git a/tests/test_sample_farthest_points.py b/tests/test_sample_farthest_points.py new file mode 100644 index 00000000..54b55f1c --- /dev/null +++ b/tests/test_sample_farthest_points.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from common_testing import TestCaseMixin, get_random_cuda_device +from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive +from pytorch3d.ops.utils import masked_gather + + +class TestFPS(TestCaseMixin, unittest.TestCase): + def test_simple(self): + device = get_random_cuda_device() + # fmt: off + points = torch.tensor( + [ + [ + [-1.0, -1.0], # noqa: E241, E201 + [-1.3, 1.1], # noqa: E241, E201 + [ 0.2, -1.1], # noqa: E241, E201 + [ 0.0, 0.0], # noqa: E241, E201 + [ 1.3, 1.3], # noqa: E241, E201 + [ 1.0, 0.5], # noqa: E241, E201 + [-1.3, 0.2], # noqa: E241, E201 + [ 1.5, -0.5], # noqa: E241, E201 + ], + [ + [-2.2, -2.4], # noqa: E241, E201 + [-2.1, 2.0], # noqa: E241, E201 + [ 2.2, 2.1], # noqa: E241, E201 + [ 2.1, -2.4], # noqa: E241, E201 + [ 0.4, -1.0], # noqa: E241, E201 + [ 0.3, 0.3], # noqa: E241, E201 + [ 1.2, 0.5], # noqa: E241, E201 + [ 4.5, 4.5], # noqa: E241, E201 + ], + ], + dtype=torch.float32, + device=device, + ) + # fmt: on + expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device) + out_points, out_inds = sample_farthest_points_naive(points, K=2) + self.assertClose(out_inds, expected_inds) + + # Gather the points + expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1]) + self.assertClose(out_points, points.gather(dim=1, index=expected_inds)) + + # Different number of points sampled for each pointcloud in the batch + expected_inds = torch.tensor( + [[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device + ) + out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2]) + self.assertClose(out_inds, expected_inds) + + # Gather the points + expected_points = masked_gather(points, expected_inds) + self.assertClose(out_points, expected_points) + + def test_random_heterogeneous(self): + device = get_random_cuda_device() + N, P, D, K = 5, 40, 5, 8 + points = torch.randn((N, P, D), device=device) + out_points, out_idxs = sample_farthest_points_naive(points, K=K) + self.assertTrue(out_idxs.min() >= 0) + for n in range(N): + self.assertEqual(out_idxs[n].ne(-1).sum(), K) + + lengths = torch.randint(low=1, high=P, size=(N,), device=device) + out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50) + + for n in range(N): + # Check that for heterogeneous batches, the max number of + # selected points is less than the length + self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n]) + self.assertTrue(out_idxs[n].max() <= lengths[n]) + + # Check there are no duplicate indices + val_mask = out_idxs[n].ne(-1) + vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True) + self.assertTrue(counts.le(1).all()) + + def test_errors(self): + device = get_random_cuda_device() + N, P, D, K = 5, 40, 5, 8 + points = torch.randn((N, P, D), device=device) + wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device) + + # K has diferent batch dimension to points + with self.assertRaisesRegex(ValueError, "K and points must have"): + sample_farthest_points_naive(points, K=wrong_batch_dim) + + # lengths has diferent batch dimension to points + with self.assertRaisesRegex(ValueError, "points and lengths must have"): + sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K) + + def test_random_start(self): + device = get_random_cuda_device() + N, P, D, K = 5, 40, 5, 8 + points = torch.randn((N, P, D), device=device) + out_points, out_idxs = sample_farthest_points_naive( + points, K=K, random_start_point=True + ) + # Check the first index is not 0 for all batch elements + # when random_start_point = True + self.assertTrue(out_idxs[:, 0].sum() > 0)