mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Farthest point sampling C++
Summary: C++ implementation of iterative farthest point sampling. Reviewed By: jcjohnson Differential Revision: D30349887 fbshipit-source-id: d25990f857752633859fe00283e182858a870269
This commit is contained in:
parent
3b7d78c7a7
commit
d9f7611c4b
@ -27,6 +27,7 @@
|
||||
#include "rasterize_meshes/rasterize_meshes.h"
|
||||
#include "rasterize_points/rasterize_points.h"
|
||||
#include "sample_pdf/sample_pdf.h"
|
||||
#include "sample_farthest_points/sample_farthest_points.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
|
||||
@ -40,9 +41,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
#endif
|
||||
m.def("knn_points_idx", &KNearestNeighborIdx);
|
||||
m.def("knn_points_backward", &KNearestNeighborBackward);
|
||||
|
||||
// Ball Query
|
||||
m.def("ball_query", &BallQuery);
|
||||
m.def("sample_farthest_points", &FarthestPointSampling);
|
||||
m.def(
|
||||
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
|
||||
m.def("gather_scatter", &GatherScatter);
|
||||
|
107
pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
Normal file
107
pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
Normal file
@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <iterator>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
// CPU implementation of iterative farthest point sampling.
//
// For each pointcloud in the batch a starting point is chosen (index 0, or a
// uniformly random valid index when random_start_point is true). Further
// points are then picked one at a time: each new point is the one whose
// squared Euclidean distance to its nearest already-selected point is largest.
//
// Args:
//    points: (N, P, D) float32 tensor holding the batch of pointclouds.
//    lengths: (N,) int64 tensor, number of valid points in each pointcloud.
//    K: (N,) int64 tensor, number of samples requested per pointcloud.
//    random_start_point: if true, the first sample is drawn at random.
//
// Returns:
//    (N, max(K)) int64 tensor of selected point indices, created with the
//    options of `lengths`. Row n holds min(lengths[n], K[n]) valid indices;
//    the remaining entries are -1.
at::Tensor FarthestPointSamplingCpu(
    const at::Tensor& points,
    const at::Tensor& lengths,
    const at::Tensor& K,
    const bool random_start_point) {
  // Get constants
  const int64_t N = points.size(0);
  const int64_t P = points.size(1);
  const int64_t D = points.size(2);
  const int64_t max_K = torch::max(K).item<int64_t>();

  // Initialize an output array for the sampled indices
  // of shape (N, max_K), padded with -1
  auto long_opts = lengths.options();
  torch::Tensor sampled_indices = torch::full({N, max_K}, -1, long_opts);

  // Create accessors for all tensors
  auto points_a = points.accessor<float, 3>();
  auto lengths_a = lengths.accessor<int64_t, 1>();
  auto k_a = K.accessor<int64_t, 1>();
  auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();

  // Accessors are not bounds checked, so validate the batch dimensions
  // before indexing lengths and K with n.
  TORCH_CHECK(
      lengths.size(0) == N && K.size(0) == N,
      "points, lengths and K must have the same batch dimension");

  // Squared distance from each point to its nearest selected point.
  // A selected point is pinned to 0 so that it never wins the max search
  // against a point that is at a positive distance.
  std::vector<float> dists;

  // Initialize random number generation for random starting points
  std::random_device rd;
  std::default_random_engine eng(rd());

  for (int64_t n = 0; n < N; ++n) {
    const int64_t length = lengths_a[n];
    TORCH_CHECK(
        length >= 0 && length <= P,
        "lengths must be in the range [0, P] for every batch element");

    // For heterogeneous pointclouds, use the minimum of the
    // length for that cloud compared to K as the number of
    // points to sample
    const int64_t batch_k = std::min(length, k_a[n]);

    // Nothing to sample (empty cloud or no samples requested):
    // leave the row padded with -1 and never index into empty storage.
    if (batch_k <= 0) {
      continue;
    }

    // Reset the distances to "infinity" for this batch element
    dists.assign(length, std::numeric_limits<float>::max());

    // Select a starting point index and save it. The distribution uses
    // int64_t to match the type of the lengths.
    std::uniform_int_distribution<int64_t> distr(0, length - 1);
    int64_t last_idx = random_start_point ? distr(eng) : 0;
    sampled_indices_a[n][0] = last_idx;

    // Mark the starting point as selected
    dists[last_idx] = 0.0f;

    // Iteratively select the remaining batch_k - 1 points
    for (int64_t k = 1; k < batch_k; ++k) {
      // Iterate through all the points
      for (int64_t p = 0; p < length; ++p) {
        // Calculate the squared distance to the last selected point
        float dist2 = 0.0f;
        for (int64_t d = 0; d < D; ++d) {
          const float diff = points_a[n][last_idx][d] - points_a[n][p][d];
          dist2 += diff * diff;
        }

        // If this point is closer to the last selected point than to any
        // of the previously selected points, update its distance. Selected
        // points stay at 0 because a squared distance is never negative.
        if (dist2 < dists[p]) {
          dists[p] = dist2;
        }
      }

      // The aim is to pick the point that has the largest
      // nearest neighbour distance to any of the already selected points
      auto itr = std::max_element(dists.begin(), dists.end());
      last_idx = std::distance(dists.begin(), itr);

      // Save selected point and mark it as selected
      sampled_indices_a[n][k] = last_idx;
      dists[last_idx] = 0.0f;
    }
  }

  return sampled_indices;
}
|
@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <tuple>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// Iterative farthest point sampling algorithm [1] to subsample a set of
|
||||
// K points from a given pointcloud. At each iteration, a point is selected
|
||||
// which has the largest nearest neighbor distance to any of the
|
||||
// already selected points.
|
||||
|
||||
// Farthest point sampling provides more uniform coverage of the input
|
||||
// point cloud compared to uniform random sampling.
|
||||
|
||||
// [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
|
||||
// on Point Sets in a Metric Space", NeurIPS 2017.
|
||||
|
||||
// Args:
|
||||
// points: (N, P, D) float32 Tensor containing the batch of pointclouds.
|
||||
// lengths: (N,) long Tensor giving the number of points in each pointcloud
|
||||
// (to support heterogeneous batches of pointclouds).
|
||||
// K: a tensor of length (N,) giving the number of
|
||||
// samples to select for each element in the batch.
|
||||
// The number of samples is typically << P.
|
||||
// random_start_point: bool, if True, a random point is selected as the
|
||||
// starting point for iterative sampling.
|
||||
// Returns:
|
||||
// selected_indices: (N, K) array of selected indices. If the values in
|
||||
// K are not all the same, then the shape will be (N, max(K)), and
|
||||
// padded with -1 for batch elements where k_i < max(K). The selected
|
||||
// points are gathered in the pytorch autograd wrapper.
|
||||
|
||||
// CPU implementation.
at::Tensor FarthestPointSamplingCpu(
    const at::Tensor& points,
    const at::Tensor& lengths,
    const at::Tensor& K,
    const bool random_start_point);

// Exposed implementation.
//
// Raises an error if any of the input tensors lives on a CUDA device, and
// otherwise forwards all arguments to FarthestPointSamplingCpu.
//
// Declared `inline` because the definition lives in a header: without it,
// every translation unit including this file would emit its own external
// definition and the link would fail with duplicate symbols.
inline at::Tensor FarthestPointSampling(
    const at::Tensor& points,
    const at::Tensor& lengths,
    const at::Tensor& K,
    const bool random_start_point) {
  if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
    AT_ERROR("CUDA implementation not yet supported");
  }
  return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
}
|
@ -8,11 +8,12 @@ from random import randint
|
||||
from typing import Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
|
||||
from .utils import masked_gather
|
||||
|
||||
|
||||
def sample_farthest_points_naive(
|
||||
def sample_farthest_points(
|
||||
points: torch.Tensor,
|
||||
lengths: Optional[torch.Tensor] = None,
|
||||
K: Union[int, List, torch.Tensor] = 50,
|
||||
@ -34,7 +35,7 @@ def sample_farthest_points_naive(
|
||||
points: (N, P, D) array containing the batch of pointclouds
|
||||
lengths: (N,) number of points in each pointcloud (to support heterogeneous
|
||||
batches of pointclouds)
|
||||
K: samples you want in each sampled point cloud (this is typically << P). If
|
||||
K: samples required in each sampled point cloud (this is typically << P). If
|
||||
K is an int then the same number of samples are selected for each
|
||||
pointcloud in the batch. If K is a tensor it should be length (N,)
|
||||
giving the number of samples to select for each element in the batch
|
||||
@ -52,6 +53,50 @@ def sample_farthest_points_naive(
|
||||
N, P, D = points.shape
|
||||
device = points.device
|
||||
|
||||
# Validate inputs
|
||||
if lengths is None:
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
if lengths.shape[0] != N:
|
||||
raise ValueError("points and lengths must have same batch dimension.")
|
||||
|
||||
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
||||
if isinstance(K, int):
|
||||
K = torch.full((N,), K, dtype=torch.int64, device=device)
|
||||
elif isinstance(K, list):
|
||||
K = torch.tensor(K, dtype=torch.int64, device=device)
|
||||
|
||||
if K.shape[0] != N:
|
||||
raise ValueError("K and points must have the same batch dimension")
|
||||
|
||||
# Check dtypes are correct and convert if necessary
|
||||
if not (points.dtype == torch.float32):
|
||||
points = points.to(torch.float32)
|
||||
if not (lengths.dtype == torch.int64):
|
||||
lengths = lengths.to(torch.int64)
|
||||
if not (K.dtype == torch.int64):
|
||||
K = K.to(torch.int64)
|
||||
|
||||
with torch.no_grad():
|
||||
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
|
||||
idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
|
||||
sampled_points = masked_gather(points, idx)
|
||||
|
||||
return sampled_points, idx
|
||||
|
||||
|
||||
def sample_farthest_points_naive(
|
||||
points: torch.Tensor,
|
||||
lengths: Optional[torch.Tensor] = None,
|
||||
K: Union[int, List, torch.Tensor] = 50,
|
||||
random_start_point: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Same Args/Returns as sample_farthest_points
|
||||
"""
|
||||
N, P, D = points.shape
|
||||
device = points.device
|
||||
|
||||
# Validate inputs
|
||||
if lengths is None:
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
37
tests/bm_sample_farthest_points.py
Normal file
37
tests/bm_sample_farthest_points.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_sample_farthest_points import TestFPS
|
||||
|
||||
|
||||
def bm_fps() -> None:
    """
    Benchmark farthest point sampling.

    The naive Python implementation is timed on every available backend;
    the C++ implementation is timed on CPU only.
    """
    # Imported here because this module has no file-level torch import.
    import torch

    # Only benchmark on CUDA when a device is actually present, so that the
    # script also runs on CPU-only machines.
    backends = ["cpu"]
    if torch.cuda.is_available():
        backends.append("cuda:0")

    Ns = [8, 32]
    Ps = [64, 256]
    Ds = [3]
    Ks = [24]
    kwargs_list = [
        {"N": N, "P": P, "D": D, "K": K, "device": d}
        for N, P, D, K, d in product(Ns, Ps, Ds, Ks, backends)
    ]

    benchmark(
        TestFPS.sample_farthest_points_naive,
        "FPS_NAIVE_PYTHON",
        kwargs_list,
        warmup_iters=1,
    )

    # Restrict the second benchmark to the CPU configurations.
    cpu_kwargs_list = [k for k in kwargs_list if k["device"] == "cpu"]
    benchmark(
        TestFPS.sample_farthest_points, "FPS_CPU", cpu_kwargs_list, warmup_iters=1
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_fps()
|
@ -8,13 +8,15 @@ import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.ops.sample_farthest_points import sample_farthest_points_naive
|
||||
from pytorch3d.ops.sample_farthest_points import (
|
||||
sample_farthest_points_naive,
|
||||
sample_farthest_points,
|
||||
)
|
||||
from pytorch3d.ops.utils import masked_gather
|
||||
|
||||
|
||||
class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
def test_simple(self):
|
||||
device = get_random_cuda_device()
|
||||
def _test_simple(self, fps_func, device="cpu"):
|
||||
# fmt: off
|
||||
points = torch.tensor(
|
||||
[
|
||||
@ -44,7 +46,7 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
|
||||
out_points, out_inds = sample_farthest_points_naive(points, K=2)
|
||||
out_points, out_inds = fps_func(points, K=2)
|
||||
self.assertClose(out_inds, expected_inds)
|
||||
|
||||
# Gather the points
|
||||
@ -55,24 +57,37 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
expected_inds = torch.tensor(
|
||||
[[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
|
||||
)
|
||||
out_points, out_inds = sample_farthest_points_naive(points, K=[3, 2])
|
||||
out_points, out_inds = fps_func(points, K=[3, 2])
|
||||
self.assertClose(out_inds, expected_inds)
|
||||
|
||||
# Gather the points
|
||||
expected_points = masked_gather(points, expected_inds)
|
||||
self.assertClose(out_points, expected_points)
|
||||
|
||||
def test_random_heterogeneous(self):
|
||||
device = get_random_cuda_device()
|
||||
N, P, D, K = 5, 40, 5, 8
|
||||
points = torch.randn((N, P, D), device=device)
|
||||
out_points, out_idxs = sample_farthest_points_naive(points, K=K)
|
||||
def _test_compare_random_heterogeneous(self, device="cpu"):
|
||||
N, P, D, K = 5, 20, 5, 8
|
||||
points = torch.randn((N, P, D), device=device, dtype=torch.float32)
|
||||
out_points_naive, out_idxs_naive = sample_farthest_points_naive(points, K=K)
|
||||
out_points, out_idxs = sample_farthest_points(points, K=K)
|
||||
self.assertTrue(out_idxs.min() >= 0)
|
||||
self.assertClose(out_idxs, out_idxs_naive)
|
||||
self.assertClose(out_points, out_points_naive)
|
||||
for n in range(N):
|
||||
self.assertEqual(out_idxs[n].ne(-1).sum(), K)
|
||||
|
||||
# Test case where K > P
|
||||
K = 30
|
||||
points1 = torch.randn((N, P, D), dtype=torch.float32, device=device)
|
||||
points2 = points1.clone()
|
||||
points1.requires_grad = True
|
||||
points2.requires_grad = True
|
||||
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
|
||||
out_points, out_idxs = sample_farthest_points_naive(points, lengths, K=50)
|
||||
out_points_naive, out_idxs_naive = sample_farthest_points_naive(
|
||||
points1, lengths, K=K
|
||||
)
|
||||
out_points, out_idxs = sample_farthest_points(points2, lengths, K=K)
|
||||
self.assertClose(out_idxs, out_idxs_naive)
|
||||
self.assertClose(out_points, out_points_naive)
|
||||
|
||||
for n in range(N):
|
||||
# Check that for heterogeneous batches, the max number of
|
||||
@ -85,8 +100,15 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
|
||||
self.assertTrue(counts.le(1).all())
|
||||
|
||||
def test_errors(self):
|
||||
device = get_random_cuda_device()
|
||||
# Check gradients
|
||||
grad_sampled_points = torch.ones((N, K, D), dtype=torch.float32, device=device)
|
||||
loss1 = (out_points_naive * grad_sampled_points).sum()
|
||||
loss1.backward()
|
||||
loss2 = (out_points * grad_sampled_points).sum()
|
||||
loss2.backward()
|
||||
self.assertClose(points1.grad, points2.grad, atol=5e-6)
|
||||
|
||||
def _test_errors(self, fps_func, device="cpu"):
|
||||
N, P, D, K = 5, 40, 5, 8
|
||||
points = torch.randn((N, P, D), device=device)
|
||||
wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)
|
||||
@ -99,8 +121,7 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "points and lengths must have"):
|
||||
sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K)
|
||||
|
||||
def test_random_start(self):
|
||||
device = get_random_cuda_device()
|
||||
def _test_random_start(self, fps_func, device="cpu"):
|
||||
N, P, D, K = 5, 40, 5, 8
|
||||
points = torch.randn((N, P, D), device=device)
|
||||
out_points, out_idxs = sample_farthest_points_naive(
|
||||
@ -109,3 +130,64 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
# Check the first index is not 0 for all batch elements
|
||||
# when random_start_point = True
|
||||
self.assertTrue(out_idxs[:, 0].sum() > 0)
|
||||
|
||||
def _test_gradcheck(self, fps_func, device="cpu"):
|
||||
N, P, D, K = 2, 5, 3, 2
|
||||
points = torch.randn(
|
||||
(N, P, D), dtype=torch.float32, device=device, requires_grad=True
|
||||
)
|
||||
torch.autograd.gradcheck(
|
||||
fps_func,
|
||||
(points, None, K),
|
||||
check_undefined_grad=False,
|
||||
eps=2e-3,
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
def test_sample_farthest_points_naive(self):
    """Run the shared checks against the naive Python implementation."""
    device = get_random_cuda_device()
    shared_checks = (
        self._test_simple,
        self._test_errors,
        self._test_random_start,
        self._test_gradcheck,
    )
    # Same checks, in the same order, each receiving the naive function.
    for check in shared_checks:
        check(sample_farthest_points_naive, device)
|
||||
|
||||
def test_sample_farthest_points_cpu(self):
    """Run the shared checks against sample_farthest_points on the CPU."""
    device = "cpu"
    fps_func = sample_farthest_points
    self._test_simple(fps_func, device)
    self._test_errors(fps_func, device)
    # This check takes only the device and compares against the naive version.
    self._test_compare_random_heterogeneous(device)
    self._test_random_start(fps_func, device)
    self._test_gradcheck(fps_func, device)
|
||||
|
||||
@staticmethod
def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
    """
    Build a benchmark closure for the naive farthest point sampling.

    Creates a random (N, P, D) float32 batch of points and a random
    (N, K, D) tensor on `device`, and returns a zero-argument callable that
    runs one forward and backward pass of sample_farthest_points_naive.
    """
    device = torch.device(device)
    pts = torch.randn(
        N, P, D, dtype=torch.float32, device=device, requires_grad=True
    )
    grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)

    def synchronize():
        # torch.cuda.synchronize() raises when CUDA is unavailable, so only
        # synchronize when the benchmark actually runs on a CUDA device.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    synchronize()

    def output():
        out_points, _ = sample_farthest_points_naive(pts, K=K)
        loss = (out_points * grad_pts).sum()
        loss.backward()
        synchronize()

    return output
|
||||
|
||||
@staticmethod
def sample_farthest_points(N: int, P: int, D: int, K: int, device: str):
    """
    Build a benchmark closure for sample_farthest_points.

    Creates a random (N, P, D) float32 batch of points and a random
    (N, K, D) tensor on `device`, and returns a zero-argument callable that
    runs one forward and backward pass of sample_farthest_points.
    """
    device = torch.device(device)
    pts = torch.randn(
        N, P, D, dtype=torch.float32, device=device, requires_grad=True
    )
    grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)

    def synchronize():
        # torch.cuda.synchronize() raises when CUDA is unavailable, so only
        # synchronize when the benchmark actually runs on a CUDA device.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    synchronize()

    def output():
        out_points, _ = sample_farthest_points(pts, K=K)
        loss = (out_points * grad_pts).sum()
        loss.backward()
        synchronize()

    return output
|
||||
|
Loading…
x
Reference in New Issue
Block a user