Farthest point sampling C++

Summary: C++ implementation of iterative farthest point sampling.

Reviewed By: jcjohnson

Differential Revision: D30349887

fbshipit-source-id: d25990f857752633859fe00283e182858a870269
This commit is contained in:
Nikhila Ravi
2021-09-15 13:47:55 -07:00
committed by Facebook GitHub Bot
parent 3b7d78c7a7
commit d9f7611c4b
6 changed files with 346 additions and 19 deletions

View File

@@ -27,6 +27,7 @@
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
#include "sample_pdf/sample_pdf.h"
#include "sample_farthest_points/sample_farthest_points.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
@@ -40,9 +41,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
// Ball Query
m.def("ball_query", &BallQuery);
m.def("sample_farthest_points", &FarthestPointSampling);
m.def(
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
m.def("gather_scatter", &GatherScatter);

View File

@@ -0,0 +1,107 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <iterator>
#include <random>
#include <vector>
at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t D = points.size(2);
const int64_t max_K = torch::max(K).item<int64_t>();
// Initialize an output array for the sampled indices
// of shape (N, max_K)
auto long_opts = lengths.options();
torch::Tensor sampled_indices = torch::full({N, max_K}, -1, long_opts);
// Create accessors for all tensors
auto points_a = points.accessor<float, 3>();
auto lengths_a = lengths.accessor<int64_t, 1>();
auto k_a = K.accessor<int64_t, 1>();
auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
// Initialize a mask to prevent duplicates
// If true, the point has already been selected.
std::vector<unsigned char> selected_points_mask(P, false);
// Initialize to infinity a vector of
// distances from each point to any of the previously selected points
std::vector<float> dists(P, std::numeric_limits<float>::max());
// Initialize random number generation for random starting points
std::random_device rd;
std::default_random_engine eng(rd());
for (int64_t n = 0; n < N; ++n) {
// Resize and reset points mask and distances for each batch
selected_points_mask.resize(lengths_a[n]);
dists.resize(lengths_a[n]);
std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());
// Select a starting point index and save it
std::uniform_int_distribution<int> distr(0, lengths_a[n] - 1);
int64_t last_idx = random_start_point ? distr(eng) : 0;
sampled_indices_a[n][0] = last_idx;
// Set the value of the mask at this point to false
selected_points_mask[last_idx] = true;
// For heterogeneous pointclouds, use the minimum of the
// length for that cloud compared to K as the number of
// points to sample
const int64_t batch_k = std::min(lengths_a[n], k_a[n]);
// Iteratively select batch_k points per batch
for (int64_t k = 1; k < batch_k; ++k) {
// Iterate through all the points
for (int64_t p = 0; p < lengths_a[n]; ++p) {
if (selected_points_mask[p]) {
// For already selected points set the distance to 0.0
dists[p] = 0.0;
continue;
}
// Calculate the distance to the last selected point
float dist2 = 0.0;
for (int64_t d = 0; d < D; ++d) {
float diff = points_a[n][last_idx][d] - points_a[n][p][d];
dist2 += diff * diff;
}
// If the distance of this point to the last selected point is closer
// than the distance to any of the previously selected points, then
// update this distance
if (dist2 < dists[p]) {
dists[p] = dist2;
}
}
// The aim is to pick the point that has the largest
// nearest neighbour distance to any of the already selected points
auto itr = std::max_element(dists.begin(), dists.end());
last_idx = std::distance(dists.begin(), itr);
// Save selected point
sampled_indices_a[n][k] = last_idx;
// Set the mask value to true to prevent duplicates.
selected_points_mask[last_idx] = true;
}
}
return sampled_indices;
}

View File

@@ -0,0 +1,56 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Iterative farthest point sampling algorithm [1] to subsample a set of
// K points from a given pointcloud. At each iteration, a point is selected
// which has the largest nearest neighbor distance to any of the
// already selected points.
// Farthest point sampling provides more uniform coverage of the input
// point cloud compared to uniform random sampling.
// [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
// on Point Sets in a Metric Space", NeurIPS 2017.
// Args:
// points: (N, P, D) float32 Tensor containing the batch of pointclouds.
// lengths: (N,) long Tensor giving the number of points in each pointcloud
// (to support heterogeneous batches of pointclouds).
// K: a tensor of length (N,) giving the number of
// samples to select for each element in the batch.
// The number of samples is typically << P.
// random_start_point: bool, if True, a random point is selected as the
// starting point for iterative sampling.
// Returns:
// selected_indices: (N, K) array of selected indices. If the values in
// K are not all the same, then the shape will be (N, max(K), D), and
// padded with -1 for batch elements where k_i < max(K). The selected
// points are gathered in the pytorch autograd wrapper.
at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point);
// Exposed implementation.
at::Tensor FarthestPointSampling(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
AT_ERROR("CUDA implementation not yet supported");
}
return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
}