Farthest point sampling C++

Summary: C++ implementation of iterative farthest point sampling.

Reviewed By: jcjohnson

Differential Revision: D30349887

fbshipit-source-id: d25990f857752633859fe00283e182858a870269
Nikhila Ravi, 2021-09-15 13:47:55 -07:00 (committed by Facebook GitHub Bot)
parent 3b7d78c7a7
commit d9f7611c4b
6 changed files with 346 additions and 19 deletions

View File

@ -27,6 +27,7 @@
#include "rasterize_meshes/rasterize_meshes.h" #include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h" #include "rasterize_points/rasterize_points.h"
#include "sample_pdf/sample_pdf.h" #include "sample_pdf/sample_pdf.h"
#include "sample_farthest_points/sample_farthest_points.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_forward", &FaceAreasNormalsForward); m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
@ -40,9 +41,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
// Ball Query
m.def("ball_query", &BallQuery);
m.def("sample_farthest_points", &FarthestPointSampling);
m.def(
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
m.def("gather_scatter", &GatherScatter);

View File

@ -0,0 +1,107 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <algorithm>
#include <iterator>
#include <limits>
#include <random>
#include <vector>
at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t D = points.size(2);
const int64_t max_K = torch::max(K).item<int64_t>();
// Initialize an output array for the sampled indices
// of shape (N, max_K)
auto long_opts = lengths.options();
torch::Tensor sampled_indices = torch::full({N, max_K}, -1, long_opts);
// Create accessors for all tensors
auto points_a = points.accessor<float, 3>();
auto lengths_a = lengths.accessor<int64_t, 1>();
auto k_a = K.accessor<int64_t, 1>();
auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
// Initialize a mask to prevent duplicates
// If true, the point has already been selected.
std::vector<unsigned char> selected_points_mask(P, false);
// Initialize to infinity a vector holding, for each point, the squared
// distance to its nearest previously selected point
std::vector<float> dists(P, std::numeric_limits<float>::max());
// Initialize random number generation for random starting points
std::random_device rd;
std::default_random_engine eng(rd());
for (int64_t n = 0; n < N; ++n) {
// Resize and reset points mask and distances for each batch
selected_points_mask.resize(lengths_a[n]);
dists.resize(lengths_a[n]);
std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());
// Select a starting point index and save it
std::uniform_int_distribution<int> distr(0, lengths_a[n] - 1);
int64_t last_idx = random_start_point ? distr(eng) : 0;
sampled_indices_a[n][0] = last_idx;
// Mark this point as selected so it cannot be chosen again
selected_points_mask[last_idx] = true;
// For heterogeneous pointclouds, use the minimum of the
// length for that cloud compared to K as the number of
// points to sample
const int64_t batch_k = std::min(lengths_a[n], k_a[n]);
// Iteratively select batch_k points per batch
for (int64_t k = 1; k < batch_k; ++k) {
// Iterate through all the points
for (int64_t p = 0; p < lengths_a[n]; ++p) {
if (selected_points_mask[p]) {
// For already selected points set the distance to 0.0
dists[p] = 0.0;
continue;
}
// Calculate the distance to the last selected point
float dist2 = 0.0;
for (int64_t d = 0; d < D; ++d) {
float diff = points_a[n][last_idx][d] - points_a[n][p][d];
dist2 += diff * diff;
}
// If this point is closer to the last selected point than to any of the
// other previously selected points, update its nearest-selected distance
if (dist2 < dists[p]) {
dists[p] = dist2;
}
}
// Pick the point whose distance to its nearest already-selected
// point is largest
auto itr = std::max_element(dists.begin(), dists.end());
last_idx = std::distance(dists.begin(), itr);
// Save selected point
sampled_indices_a[n][k] = last_idx;
// Set the mask value to true to prevent duplicates.
selected_points_mask[last_idx] = true;
}
}
return sampled_indices;
}
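For intuition, the selection loop above maps to roughly the following PyTorch sketch. This is a simplified, homogeneous-batch illustration only (the name fps_sketch is hypothetical); the library's own Python reference is sample_farthest_points_naive.

import torch

def fps_sketch(points: torch.Tensor, K: int, random_start_point: bool = False) -> torch.Tensor:
    # points: (N, P, D) float32; returns (N, K) long indices.
    N, P, D = points.shape
    idx = torch.full((N, K), -1, dtype=torch.int64)
    for n in range(N):
        # Squared distance from each point to its nearest selected point.
        dists = torch.full((P,), float("inf"))
        last = int(torch.randint(P, (1,))) if random_start_point else 0
        idx[n, 0] = last
        for k in range(1, K):
            # Distance of every point to the most recently selected point.
            d2 = ((points[n] - points[n, last]) ** 2).sum(dim=1)
            # Keep the nearest-selected distance seen so far
            # (this also zeroes out already selected points).
            dists = torch.minimum(dists, d2)
            # Next sample: the point farthest from its nearest selected point.
            last = int(torch.argmax(dists))
            idx[n, k] = last
    return idx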

View File

@ -0,0 +1,56 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Iterative farthest point sampling algorithm [1] to subsample a set of
// K points from a given pointcloud. At each iteration, a point is selected
// which has the largest nearest neighbor distance to any of the
// already selected points.
// Farthest point sampling provides more uniform coverage of the input
// point cloud compared to uniform random sampling.
// [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
// on Point Sets in a Metric Space", NeurIPS 2017.
// Args:
// points: (N, P, D) float32 Tensor containing the batch of pointclouds.
// lengths: (N,) long Tensor giving the number of points in each pointcloud
// (to support heterogeneous batches of pointclouds).
// K: a tensor of length (N,) giving the number of
// samples to select for each element in the batch.
// The number of samples is typically << P.
// random_start_point: bool, if True, a random point is selected as the
// starting point for iterative sampling.
// Returns:
//  selected_indices: (N, K) array of selected indices. If the values in
//  K are not all the same, then the shape will be (N, max(K)), and
//  padded with -1 for batch elements where k_i < max(K). The selected
//  points are gathered in the pytorch autograd wrapper.
at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point);
// Exposed API: dispatches to the CPU implementation above (no CUDA kernel yet).
at::Tensor FarthestPointSampling(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
AT_ERROR("CUDA implementation not yet supported");
}
return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
}
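A usage sketch of how the op is reached from Python through the wrapper added in the next file. The shapes and the -1 padding follow the Returns note above; the exact sampled values depend on the input.

import torch
from pytorch3d.ops.sample_farthest_points import sample_farthest_points

points = torch.randn(2, 1000, 3)      # (N, P, D) float32 pointclouds
lengths = torch.tensor([1000, 600])   # heterogeneous cloud sizes
K = torch.tensor([64, 32])            # per-cloud sample counts

sampled, idx = sample_farthest_points(points, lengths, K)
# sampled: (2, 64, 3) selected points, idx: (2, 64) selected indices;
# rows of idx where k_i < max(K) are padded with -1.
assert (idx[1, 32:] == -1).all()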

View File

@ -8,11 +8,12 @@ from random import randint
from typing import Optional, Tuple, Union, List
import torch
from pytorch3d import _C
from .utils import masked_gather
-def sample_farthest_points_naive(
def sample_farthest_points(
points: torch.Tensor,
lengths: Optional[torch.Tensor] = None,
K: Union[int, List, torch.Tensor] = 50,
@ -34,7 +35,7 @@ def sample_farthest_points_naive(
points: (N, P, D) array containing the batch of pointclouds
lengths: (N,) number of points in each pointcloud (to support heterogeneous
batches of pointclouds)
-K: samples you want in each sampled point cloud (this is typically << P). If
K: samples required in each sampled point cloud (this is typically << P). If
K is an int then the same number of samples are selected for each
pointcloud in the batch. If K is a tensor it should be length (N,)
giving the number of samples to select for each element in the batch
@ -52,6 +53,50 @@ def sample_farthest_points_naive(
N, P, D = points.shape
device = points.device
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
if lengths.shape[0] != N:
raise ValueError("points and lengths must have same batch dimension.")
# TODO: support providing K as a ratio of the total number of points instead of as an int
if isinstance(K, int):
K = torch.full((N,), K, dtype=torch.int64, device=device)
elif isinstance(K, list):
K = torch.tensor(K, dtype=torch.int64, device=device)
if K.shape[0] != N:
raise ValueError("K and points must have the same batch dimension")
# Check dtypes are correct and convert if necessary
if not (points.dtype == torch.float32):
points = points.to(torch.float32)
if not (lengths.dtype == torch.int64):
lengths = lengths.to(torch.int64)
if not (K.dtype == torch.int64):
K = K.to(torch.int64)
with torch.no_grad():
# pyre-fixme[16]: `pytorch3d._C` has no attribute `sample_farthest_points`.
idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
sampled_points = masked_gather(points, idx)
return sampled_points, idx
def sample_farthest_points_naive(
points: torch.Tensor,
lengths: Optional[torch.Tensor] = None,
K: Union[int, List, torch.Tensor] = 50,
random_start_point: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Same Args/Returns as sample_farthest_points
"""
N, P, D = points.shape
device = points.device
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
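Note that the C++ op is called under torch.no_grad() and returns only indices; gradients reach the input through masked_gather, which gathers the selected points inside autograd. A minimal sketch of that behaviour:

import torch
from pytorch3d.ops.sample_farthest_points import sample_farthest_points

pts = torch.randn(2, 100, 3, requires_grad=True)
sampled, idx = sample_farthest_points(pts, K=16)
sampled.sum().backward()
# pts.grad has shape (2, 100, 3); only the K selected rows per cloud
# receive nonzero gradients, since the indices themselves are
# non-differentiable.
assert pts.grad is not None and pts.grad.shape == (2, 100, 3)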

View File

@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
from fvcore.common.benchmark import benchmark
from test_sample_farthest_points import TestFPS
def bm_fps() -> None:
kwargs_list = []
backends = ["cpu", "cuda:0"]
Ns = [8, 32]
Ps = [64, 256]
Ds = [3]
Ks = [24]
test_cases = product(Ns, Ps, Ds, Ks, backends)
for case in test_cases:
N, P, D, K, d = case
kwargs_list.append({"N": N, "P": P, "D": D, "K": K, "device": d})
benchmark(
TestFPS.sample_farthest_points_naive,
"FPS_NAIVE_PYTHON",
kwargs_list,
warmup_iters=1,
)
kwargs_list = [k for k in kwargs_list if k["device"] == "cpu"]
benchmark(TestFPS.sample_farthest_points, "FPS_CPU", kwargs_list, warmup_iters=1)
if __name__ == "__main__":
bm_fps()

View File

@ -8,13 +8,15 @@ import unittest
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
-from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive
from pytorch3d.ops.sample_farthest_points import (
sample_farthest_points_naive,
sample_farthest_points,
)
from pytorch3d.ops.utils import masked_gather
class TestFPS(TestCaseMixin, unittest.TestCase):
-def test_simple(self):
-device = get_random_cuda_device()
def _test_simple(self, fps_func, device="cpu"):
# fmt: off
points = torch.tensor(
[
@ -44,7 +46,7 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
)
# fmt: on
expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
-out_points, out_inds = sample_farthest_points_naive(points, K=2)
out_points, out_inds = fps_func(points, K=2)
self.assertClose(out_inds, expected_inds)
# Gather the points
@ -55,24 +57,37 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
expected_inds = torch.tensor(
[[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
)
-out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2])
out_points, out_inds = fps_func(points, K=[3, 2])
self.assertClose(out_inds, expected_inds)
# Gather the points
expected_points = masked_gather(points, expected_inds)
self.assertClose(out_points, expected_points)
-def test_random_heterogeneous(self):
-device = get_random_cuda_device()
-N, P, D, K = 5, 40, 5, 8
-points = torch.randn((N, P, D), device=device)
-out_points, out_idxs = sample_farthest_points_naive(points, K=K)
def _test_compare_random_heterogeneous(self, device="cpu"):
N, P, D, K = 5, 20, 5, 8
points = torch.randn((N, P, D), device=device, dtype=torch.float32)
out_points_naive, out_idxs_naive = sample_farthest_points_naive(points, K=K)
out_points, out_idxs = sample_farthest_points(points, K=K)
self.assertTrue(out_idxs.min() >= 0)
self.assertClose(out_idxs, out_idxs_naive)
self.assertClose(out_points, out_points_naive)
for n in range(N):
self.assertEqual(out_idxs[n].ne(-1).sum(), K)
# Test case where K > P
K = 30
points1 = torch.randn((N, P, D), dtype=torch.float32, device=device)
points2 = points1.clone()
points1.requires_grad = True
points2.requires_grad = True
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
-out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50)
out_points_naive, out_idxs_naive = sample_farthest_points_naive(
points1, lengths, K=K
)
out_points, out_idxs = sample_farthest_points(points2, lengths, K=K)
self.assertClose(out_idxs, out_idxs_naive)
self.assertClose(out_points, out_points_naive)
for n in range(N):
# Check that for heterogeneous batches, the max number of
@ -85,8 +100,15 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
self.assertTrue(counts.le(1).all())
# Check gradients
grad_sampled_points = torch.ones((N, K, D), dtype=torch.float32, device=device)
loss1 = (out_points_naive * grad_sampled_points).sum()
loss1.backward()
loss2 = (out_points * grad_sampled_points).sum()
loss2.backward()
self.assertClose(points1.grad, points2.grad, atol=5e-6)
-def test_errors(self):
-device = get_random_cuda_device()
def _test_errors(self, fps_func, device="cpu"):
N, P, D, K = 5, 40, 5, 8
points = torch.randn((N, P, D), device=device)
wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)
@ -99,8 +121,7 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "points and lengths must have"):
sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K)
-def test_random_start(self):
-device = get_random_cuda_device()
def _test_random_start(self, fps_func, device="cpu"):
N, P, D, K = 5, 40, 5, 8
points = torch.randn((N, P, D), device=device)
out_points, out_idxs = sample_farthest_points_naive(
@ -109,3 +130,64 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
# Check the first index is not 0 for all batch elements
# when random_start_point = True
self.assertTrue(out_idxs[:, 0].sum() > 0)
def _test_gradcheck(self, fps_func, device="cpu"):
N, P, D, K = 2, 5, 3, 2
points = torch.randn(
(N, P, D), dtype=torch.float32, device=device, requires_grad=True
)
torch.autograd.gradcheck(
fps_func,
(points, None, K),
check_undefined_grad=False,
eps=2e-3,
atol=0.001,
)
def test_sample_farthest_points_naive(self):
device = get_random_cuda_device()
self._test_simple(sample_farthest_points_naive, device)
self._test_errors(sample_farthest_points_naive, device)
self._test_random_start(sample_farthest_points_naive, device)
self._test_gradcheck(sample_farthest_points_naive, device)
def test_sample_farthest_points_cpu(self):
self._test_simple(sample_farthest_points, "cpu")
self._test_errors(sample_farthest_points, "cpu")
self._test_compare_random_heterogeneous("cpu")
self._test_random_start(sample_farthest_points, "cpu")
self._test_gradcheck(sample_farthest_points, "cpu")
@staticmethod
def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
device = torch.device(device)
pts = torch.randn(
N, P, D, dtype=torch.float32, device=device, requires_grad=True
)
grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
torch.cuda.synchronize()
def output():
out_points, _ = sample_farthest_points_naive(pts, K=K)
loss = (out_points * grad_pts).sum()
loss.backward()
torch.cuda.synchronize()
return output
@staticmethod
def sample_farthest_points(N: int, P: int, D: int, K: int, device: str):
device = torch.device(device)
pts = torch.randn(
N, P, D, dtype=torch.float32, device=device, requires_grad=True
)
grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
torch.cuda.synchronize()
def output():
out_points, _ = sample_farthest_points(pts, K=K)
loss = (out_points * grad_pts).sum()
loss.backward()
torch.cuda.synchronize()
return output