diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp
index 4fa78def..3b9a9edc 100644
--- a/pytorch3d/csrc/ext.cpp
+++ b/pytorch3d/csrc/ext.cpp
@@ -27,6 +27,7 @@
 #include "rasterize_meshes/rasterize_meshes.h"
 #include "rasterize_points/rasterize_points.h"
 #include "sample_pdf/sample_pdf.h"
+#include "sample_farthest_points/sample_farthest_points.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
@@ -40,9 +41,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #endif
   m.def("knn_points_idx", &KNearestNeighborIdx);
   m.def("knn_points_backward", &KNearestNeighborBackward);
-
-  // Ball Query
   m.def("ball_query", &BallQuery);
+  m.def("sample_farthest_points", &FarthestPointSampling);
   m.def(
       "mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
   m.def("gather_scatter", &GatherScatter);
diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
new file mode 100644
index 00000000..f4167ac5
--- /dev/null
+++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <torch/extension.h>
+#include <algorithm>
+#include <random>
+#include <vector>
+
+at::Tensor FarthestPointSamplingCpu(
+    const at::Tensor& points,
+    const at::Tensor& lengths,
+    const at::Tensor& K,
+    const bool random_start_point) {
+  // Get constants
+  const int64_t N = points.size(0);
+  const int64_t P = points.size(1);
+  const int64_t D = points.size(2);
+  const int64_t max_K = torch::max(K).item<int64_t>();
+
+  // Initialize an output array for the sampled indices
+  // of shape (N, max_K)
+  auto long_opts = lengths.options();
+  torch::Tensor sampled_indices = torch::full({N, max_K}, -1, long_opts);
+
+  // Create accessors for all tensors
+  auto points_a = points.accessor<float, 3>();
+  auto lengths_a = lengths.accessor<int64_t, 1>();
+  auto k_a = K.accessor<int64_t, 1>();
+  auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
+
+  // Initialize a mask to prevent duplicates
+  // If true, the point has already been selected.
+  std::vector<bool> selected_points_mask(P, false);
+
+  // Initialize to infinity a vector of
+  // distances from each point to any of the previously selected points
+  std::vector<float> dists(P, std::numeric_limits<float>::max());
+
+  // Initialize random number generation for random starting points
+  std::random_device rd;
+  std::default_random_engine eng(rd());
+
+  for (int64_t n = 0; n < N; ++n) {
+    // Resize and reset points mask and distances for each batch
+    selected_points_mask.resize(lengths_a[n]);
+    dists.resize(lengths_a[n]);
+    std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
+    std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());
+
+    // Select a starting point index and save it
+    std::uniform_int_distribution<int64_t> distr(0, lengths_a[n] - 1);
+    int64_t last_idx = random_start_point ? distr(eng) : 0;
+    sampled_indices_a[n][0] = last_idx;
+
+    // Set the value of the mask at this point to true
+    selected_points_mask[last_idx] = true;
+
+    // For heterogeneous pointclouds, use the minimum of the
+    // length for that cloud compared to K as the number of
+    // points to sample
+    const int64_t batch_k = std::min(lengths_a[n], k_a[n]);
+
+    // Iteratively select batch_k points per batch
+    for (int64_t k = 1; k < batch_k; ++k) {
+      // Iterate through all the points
+      for (int64_t p = 0; p < lengths_a[n]; ++p) {
+        if (selected_points_mask[p]) {
+          // For already selected points set the distance to 0.0
+          dists[p] = 0.0;
+          continue;
+        }
+
+        // Calculate the distance to the last selected point
+        float dist2 = 0.0;
+        for (int64_t d = 0; d < D; ++d) {
+          float diff = points_a[n][last_idx][d] - points_a[n][p][d];
+          dist2 += diff * diff;
+        }
+
+        // If the distance of this point to the last selected point is closer
+        // than the distance to any of the previously selected points, then
+        // update this distance
+        if (dist2 < dists[p]) {
+          dists[p] = dist2;
+        }
+      }
+
+      // The aim is to pick the point that has the largest
+      // nearest neighbour distance to any of the already selected points
+      auto itr = std::max_element(dists.begin(), dists.end());
+      last_idx = std::distance(dists.begin(), itr);
+
+      // Save selected point
+      sampled_indices_a[n][k] = last_idx;
+
+      // Set the mask value to true to prevent duplicates.
+      selected_points_mask[last_idx] = true;
+    }
+  }
+
+  return sampled_indices;
+}
diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
new file mode 100644
index 00000000..bb4456d3
--- /dev/null
+++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <torch/extension.h>
+#include <tuple>
+#include "utils/pytorch3d_cutils.h"
+
+// Iterative farthest point sampling algorithm [1] to subsample a set of
+// K points from a given pointcloud. At each iteration, a point is selected
+// which has the largest nearest neighbor distance to any of the
+// already selected points.
+
+// Farthest point sampling provides more uniform coverage of the input
+// point cloud compared to uniform random sampling.
+
+// [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
+// on Point Sets in a Metric Space", NeurIPS 2017.
+
+// Args:
+//     points: (N, P, D) float32 Tensor containing the batch of pointclouds.
+//     lengths: (N,) long Tensor giving the number of points in each pointcloud
+//         (to support heterogeneous batches of pointclouds).
+//     K: a tensor of length (N,) giving the number of
+//         samples to select for each element in the batch.
+//         The number of samples is typically << P.
+//     random_start_point: bool, if True, a random point is selected as the
+//         starting point for iterative sampling.
+// Returns:
+//     selected_indices: (N, K) array of selected indices. If the values in
+//         K are not all the same, then the shape will be (N, max(K)), and
+//         padded with -1 for batch elements where k_i < max(K). The selected
+//         points are gathered in the pytorch autograd wrapper.
+
+at::Tensor FarthestPointSamplingCpu(
+    const at::Tensor& points,
+    const at::Tensor& lengths,
+    const at::Tensor& K,
+    const bool random_start_point);
+
+// Exposed implementation.
+at::Tensor FarthestPointSampling(
+    const at::Tensor& points,
+    const at::Tensor& lengths,
+    const at::Tensor& K,
+    const bool random_start_point) {
+  if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
+    AT_ERROR("CUDA implementation not yet supported");
+  }
+  return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
+}
diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py
index 071f3ffe..c75253f6 100644
--- a/pytorch3d/ops/sample_farthest_points.py
+++ b/pytorch3d/ops/sample_farthest_points.py
@@ -8,11 +8,12 @@ from random import randint
 from typing import Optional, Tuple, Union, List
 
 import torch
+from pytorch3d import _C
 
 from .utils import masked_gather
 
 
-def sample_farthest_points_naive(
+def sample_farthest_points(
     points: torch.Tensor,
     lengths: Optional[torch.Tensor] = None,
     K: Union[int, List, torch.Tensor] = 50,
@@ -34,7 +35,7 @@ def sample_farthest_points(
         points: (N, P, D) array containing the batch of pointclouds
         lengths: (N,) number of points in each pointcloud (to support heterogeneous
             batches of pointclouds)
-        K: samples you want in each sampled point cloud (this is typically << P). If
+        K: samples required in each sampled point cloud (this is typically << P). If
             K is an int then the same number of samples are selected for each
             pointcloud in the batch. If K is a tensor is should be length (N,)
             giving the number of samples to select for each element in the batch
@@ -52,6 +53,50 @@ def sample_farthest_points(
     N, P, D = points.shape
     device = points.device
 
+    # Validate inputs
+    if lengths is None:
+        lengths = torch.full((N,), P, dtype=torch.int64, device=device)
+
+    if lengths.shape[0] != N:
+        raise ValueError("points and lengths must have same batch dimension.")
+
+    # TODO: support providing K as a ratio of the total number of points instead of as an int
+    if isinstance(K, int):
+        K = torch.full((N,), K, dtype=torch.int64, device=device)
+    elif isinstance(K, list):
+        K = torch.tensor(K, dtype=torch.int64, device=device)
+
+    if K.shape[0] != N:
+        raise ValueError("K and points must have the same batch dimension")
+
+    # Check dtypes are correct and convert if necessary
+    if not (points.dtype == torch.float32):
+        points = points.to(torch.float32)
+    if not (lengths.dtype == torch.int64):
+        lengths = lengths.to(torch.int64)
+    if not (K.dtype == torch.int64):
+        K = K.to(torch.int64)
+
+    with torch.no_grad():
+        # pyre-fixme[16]: `pytorch3d._C` has no attribute `sample_farthest_points`.
+        idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
+    sampled_points = masked_gather(points, idx)
+
+    return sampled_points, idx
+
+
+def sample_farthest_points_naive(
+    points: torch.Tensor,
+    lengths: Optional[torch.Tensor] = None,
+    K: Union[int, List, torch.Tensor] = 50,
+    random_start_point: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Same Args/Returns as sample_farthest_points
+    """
+    N, P, D = points.shape
+    device = points.device
+
     # Validate inputs
     if lengths is None:
         lengths = torch.full((N,), P, dtype=torch.int64, device=device)
diff --git a/tests/bm_sample_farthest_points.py b/tests/bm_sample_farthest_points.py
new file mode 100644
index 00000000..5108d1d1
--- /dev/null
+++ b/tests/bm_sample_farthest_points.py
@@ -0,0 +1,37 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import product
+
+from fvcore.common.benchmark import benchmark
+from test_sample_farthest_points import TestFPS
+
+
+def bm_fps() -> None:
+    kwargs_list = []
+    backends = ["cpu", "cuda:0"]
+    Ns = [8, 32]
+    Ps = [64, 256]
+    Ds = [3]
+    Ks = [24]
+    test_cases = product(Ns, Ps, Ds, Ks, backends)
+    for case in test_cases:
+        N, P, D, K, d = case
+        kwargs_list.append({"N": N, "P": P, "D": D, "K": K, "device": d})
+
+    benchmark(
+        TestFPS.sample_farthest_points_naive,
+        "FPS_NAIVE_PYTHON",
+        kwargs_list,
+        warmup_iters=1,
+    )
+
+    kwargs_list = [k for k in kwargs_list if k["device"] == "cpu"]
+    benchmark(TestFPS.sample_farthest_points, "FPS_CPU", kwargs_list, warmup_iters=1)
+
+
+if __name__ == "__main__":
+    bm_fps()
diff --git a/tests/test_sample_farthest_points.py b/tests/test_sample_farthest_points.py
index 54b55f1c..7b071b18 100644
--- a/tests/test_sample_farthest_points.py
+++ b/tests/test_sample_farthest_points.py
@@ -8,13 +8,15 @@ import unittest
 
 import torch
 from common_testing import TestCaseMixin, get_random_cuda_device
-from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive
+from pytorch3d.ops.sample_farthest_points import (
+    sample_farthest_points_naive,
+    sample_farthest_points,
+)
 from pytorch3d.ops.utils import masked_gather
 
 
 class TestFPS(TestCaseMixin, unittest.TestCase):
-    def test_simple(self):
-        device = get_random_cuda_device()
+    def _test_simple(self, fps_func, device="cpu"):
         # fmt: off
         points = torch.tensor(
             [
@@ -44,7 +46,7 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
         )
         # fmt: on
         expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
-        out_points, out_inds = sample_farthest_points_naive(points, K=2)
+        out_points, out_inds = fps_func(points, K=2)
         self.assertClose(out_inds, expected_inds)
 
         # Gather the points
@@ -55,24 +57,37 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
         expected_inds = torch.tensor(
             [[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
         )
-        out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2])
+        out_points, out_inds = fps_func(points, K=[3, 2])
         self.assertClose(out_inds, expected_inds)
 
         # Gather the points
         expected_points = masked_gather(points, expected_inds)
         self.assertClose(out_points, expected_points)
 
-    def test_random_heterogeneous(self):
-        device = get_random_cuda_device()
-        N, P, D, K = 5, 40, 5, 8
-        points = torch.randn((N, P, D), device=device)
-        out_points, out_idxs = sample_farthest_points_naive(points, K=K)
+    def _test_compare_random_heterogeneous(self, device="cpu"):
+        N, P, D, K = 5, 20, 5, 8
+        points = torch.randn((N, P, D), device=device, dtype=torch.float32)
+        out_points_naive, out_idxs_naive = sample_farthest_points_naive(points, K=K)
+        out_points, out_idxs = sample_farthest_points(points, K=K)
         self.assertTrue(out_idxs.min() >= 0)
+        self.assertClose(out_idxs, out_idxs_naive)
+        self.assertClose(out_points, out_points_naive)
         for n in range(N):
             self.assertEqual(out_idxs[n].ne(-1).sum(), K)
 
+        # Test case where K > P
+        K = 30
+        points1 = torch.randn((N, P, D), dtype=torch.float32, device=device)
+        points2 = points1.clone()
+        points1.requires_grad = True
+        points2.requires_grad = True
         lengths = torch.randint(low=1, high=P, size=(N,), device=device)
-        out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50)
+        out_points_naive, out_idxs_naive = sample_farthest_points_naive(
+            points1, lengths, K=K
+        )
+        out_points, out_idxs = sample_farthest_points(points2, lengths, K=K)
+        self.assertClose(out_idxs, out_idxs_naive)
+        self.assertClose(out_points, out_points_naive)
 
         for n in range(N):
             # Check that for heterogeneous batches, the max number of
@@ -85,8 +100,15 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
             vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
             self.assertTrue(counts.le(1).all())
 
-    def test_errors(self):
-        device = get_random_cuda_device()
+        # Check gradients
+        grad_sampled_points = torch.ones((N, K, D), dtype=torch.float32, device=device)
+        loss1 = (out_points_naive * grad_sampled_points).sum()
+        loss1.backward()
+        loss2 = (out_points * grad_sampled_points).sum()
+        loss2.backward()
+        self.assertClose(points1.grad, points2.grad, atol=5e-6)
+
+    def _test_errors(self, fps_func, device="cpu"):
         N, P, D, K = 5, 40, 5, 8
         points = torch.randn((N, P, D), device=device)
         wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)
@@ -99,8 +121,7 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "points and lengths must have"):
             sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K)
 
-    def test_random_start(self):
-        device = get_random_cuda_device()
+    def _test_random_start(self, fps_func, device="cpu"):
         N, P, D, K = 5, 40, 5, 8
         points = torch.randn((N, P, D), device=device)
         out_points, out_idxs = sample_farthest_points_naive(
@@ -109,3 +130,64 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
         # Check the first index is not 0 for all batch elements
         # when random_start_point = True
         self.assertTrue(out_idxs[:, 0].sum() > 0)
+
+    def _test_gradcheck(self, fps_func, device="cpu"):
+        N, P, D, K = 2, 5, 3, 2
+        points = torch.randn(
+            (N, P, D), dtype=torch.float32, device=device, requires_grad=True
+        )
+        torch.autograd.gradcheck(
+            fps_func,
+            (points, None, K),
+            check_undefined_grad=False,
+            eps=2e-3,
+            atol=0.001,
+        )
+
+    def test_sample_farthest_points_naive(self):
+        device = get_random_cuda_device()
+        self._test_simple(sample_farthest_points_naive, device)
+        self._test_errors(sample_farthest_points_naive, device)
+        self._test_random_start(sample_farthest_points_naive, device)
+        self._test_gradcheck(sample_farthest_points_naive, device)
+
+    def test_sample_farthest_points_cpu(self):
+        self._test_simple(sample_farthest_points, "cpu")
+        self._test_errors(sample_farthest_points, "cpu")
+        self._test_compare_random_heterogeneous("cpu")
+        self._test_random_start(sample_farthest_points, "cpu")
+        self._test_gradcheck(sample_farthest_points, "cpu")
+
+    @staticmethod
+    def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
+        device = torch.device(device)
+        pts = torch.randn(
+            N, P, D, dtype=torch.float32, device=device, requires_grad=True
+        )
+        grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
+        torch.cuda.synchronize()
+
+        def output():
+            out_points, _ = sample_farthest_points_naive(pts, K=K)
+            loss = (out_points * grad_pts).sum()
+            loss.backward()
+            torch.cuda.synchronize()
+
+        return output
+
+    @staticmethod
+    def sample_farthest_points(N: int, P: int, D: int, K: int, device: str):
+        device = torch.device(device)
+        pts = torch.randn(
+            N, P, D, dtype=torch.float32, device=device, requires_grad=True
+        )
+        grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
+        torch.cuda.synchronize()
+
+        def output():
+            out_points, _ = sample_farthest_points(pts, K=K)
+            loss = (out_points * grad_pts).sum()
+            loss.backward()
+            torch.cuda.synchronize()
+
+        return output
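
Usage note (illustrative, not part of the patch): the sketch below shows how the new `sample_farthest_points` op wired up in this diff could be called, assuming the extension has been rebuilt so that `_C.sample_farthest_points` is available. The batch size, point counts and `lengths` values are made up for the example, and only the CPU path is exercised since no CUDA kernel is added here.

import torch
from pytorch3d.ops.sample_farthest_points import sample_farthest_points

# Two heterogeneous pointclouds padded to P=100 points in D=3 dimensions.
points = torch.randn(2, 100, 3, dtype=torch.float32)
lengths = torch.tensor([100, 60], dtype=torch.int64)  # valid points per cloud

# Select K=16 points per cloud; indices are padded with -1 wherever K exceeds
# the cloud length.
sampled_points, idx = sample_farthest_points(
    points, lengths=lengths, K=16, random_start_point=True
)
print(sampled_points.shape)  # torch.Size([2, 16, 3])
print(idx.shape)  # torch.Size([2, 16])

Because `masked_gather` is applied outside the `torch.no_grad()` block, `sampled_points` stays differentiable with respect to `points`, which is what the gradient comparison added to the tests relies on.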