Farthest point sampling CUDA

Summary:
CUDA implementation of farthest point sampling algorithm.

## Visual comparison

Compared to random sampling, farthest point sampling gives better coverage of the shape.

{F658631262}

## Reduction

Parallelized block reduction to find the max value at each iteration happens as follows:

1. First split the points into two equal sized parts (e.g. for a list with 8 values):
`[20, 27, 6, 8 | 11, 10, 2, 33]`
2. Use half of the thread (4 threads) to compare pairs of elements from each half (e.g elements [0, 4], [1, 5] etc) and store the result in the first half of the list:
`[20, 27, 6, 33 | 11, 10, 2, 33]`
Now we no longer care about the second part but again divide the first part into two
`[20, 27 | 6, 33| -, -, -, -]`
Now we can use 2 threads to compare the 4 elements
4. Finally we have gotten down to a single pair
`[20 | 33 | -, - | -, -, -, -]`
Use 1 thread to compare the remaining two elements
5. The max will now be at thread id = 0
`[33 | - | -, - | -, -, -, -]`
The reduction will give the farthest point for the selected batch index at this iteration.

Reviewed By: bottler, jcjohnson

Differential Revision: D30401803

fbshipit-source-id: 525bd5ae27c4b13b501812cfe62306bb003827d2
This commit is contained in:
Nikhila Ravi 2021-09-15 13:47:55 -07:00 committed by Facebook GitHub Bot
parent d9f7611c4b
commit bd04ffaf77
11 changed files with 441 additions and 33 deletions

View File

@ -26,8 +26,8 @@
#include "point_mesh/point_mesh_cuda.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
#include "sample_pdf/sample_pdf.h"
#include "sample_farthest_points/sample_farthest_points.h"
#include "sample_pdf/sample_pdf.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);

View File

@ -121,7 +121,7 @@ __global__ void DistanceForwardKernel(
// Unroll the last 6 iterations of the loop since they will happen
// synchronized within a single warp.
if (tid < 32)
WarpReduce<float>(min_dists, min_idxs, tid);
WarpReduceMin<float>(min_dists, min_idxs, tid);
// Finally thread 0 writes the result to the output buffer.
if (tid == 0) {

View File

@ -15,7 +15,7 @@ at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
const at::Tensor& start_idxs) {
// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
@ -32,6 +32,7 @@ at::Tensor FarthestPointSamplingCpu(
auto lengths_a = lengths.accessor<int64_t, 1>();
auto k_a = K.accessor<int64_t, 1>();
auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
auto start_idxs_a = start_idxs.accessor<int64_t, 1>();
// Initialize a mask to prevent duplicates
// If true, the point has already been selected.
@ -41,10 +42,6 @@ at::Tensor FarthestPointSamplingCpu(
// distances from each point to any of the previously selected points
std::vector<float> dists(P, std::numeric_limits<float>::max());
// Initialize random number generation for random starting points
std::random_device rd;
std::default_random_engine eng(rd());
for (int64_t n = 0; n < N; ++n) {
// Resize and reset points mask and distances for each batch
selected_points_mask.resize(lengths_a[n]);
@ -52,9 +49,8 @@ at::Tensor FarthestPointSamplingCpu(
std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());
// Select a starting point index and save it
std::uniform_int_distribution<int> distr(0, lengths_a[n] - 1);
int64_t last_idx = random_start_point ? distr(eng) : 0;
// Get the starting point index and save it
int64_t last_idx = start_idxs_a[n];
sampled_indices_a[n][0] = last_idx;
// Set the value of the mask at this point to false

View File

@ -0,0 +1,252 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "utils/pytorch3d_cutils.h"
#include "utils/warp_reduce.cuh"
template <unsigned int block_size>
__global__ void FarthestPointSamplingKernel(
// clang-format off
const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> points,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> lengths,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> K,
at::PackedTensorAccessor64<int64_t, 2, at::RestrictPtrTraits> idxs,
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> min_point_dist,
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
// clang-format on
) {
// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t D = points.size(2);
// Create single shared memory buffer which is split and cast to different
// types: dists/dists_idx are used to save the maximum distances seen by the
// points processed by any one thread and the associated point indices.
// These values only need to be accessed by other threads in this block which
// are processing the same batch and not by other blocks.
extern __shared__ char shared_buf[];
float* dists = (float*)shared_buf; // block_size floats
int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t
// Get batch index and thread index
const int64_t batch_idx = blockIdx.x;
const size_t tid = threadIdx.x;
// If K is greater than the number of points in the pointcloud
// we only need to iterate until the smaller value is reached.
const int64_t k_n = min(K[batch_idx], lengths[batch_idx]);
// Write the first selected point to global memory in the first thread
int64_t selected = start_idxs[batch_idx];
if (tid == 0)
idxs[batch_idx][0] = selected;
// Iterate to find k_n sampled points
for (int64_t k = 1; k < k_n; ++k) {
// Keep track of the maximum of the minimum distance to previously selected
// points seen by this thread
int64_t max_dist_idx = 0;
float max_dist = -1.0;
// Iterate through all the points in this pointcloud. For already selected
// points, the minimum distance to the set of previously selected points
// will be 0.0 so they won't be selected again.
for (int64_t p = tid; p < lengths[batch_idx]; p += block_size) {
// Calculate the distance to the last selected point
float dist2 = 0.0;
for (int64_t d = 0; d < D; ++d) {
float diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
dist2 += (diff * diff);
}
// If the distance of point p to the last selected point is
// less than the previous minimum distance of p to the set of selected
// points, then updated the corresponding value in min_point_dist
// so it always contains the min distance.
const float p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
min_point_dist[batch_idx][p] = p_min_dist;
// Update the max distance and point idx for this thread.
max_dist_idx = (p_min_dist > max_dist) ? p : max_dist_idx;
max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
}
// After going through all points for this thread, save the max
// point and idx seen by this thread. Each thread sees P/block_size points.
dists[tid] = max_dist;
dists_idx[tid] = max_dist_idx;
// Sync to ensure all threads in the block have updated their max point.
__syncthreads();
// Parallelized block reduction to find the max point seen by
// all the threads in this block for iteration k.
// Each block represents one batch element so we can use a divide/conquer
// approach to find the max, syncing all threads after each step.
for (int s = block_size / 2; s > 0; s >>= 1) {
if (tid < s) {
// Compare the best point seen by two threads and update the shared
// memory at the location of the first thread index with the max out
// of the two threads.
if (dists[tid] < dists[tid + s]) {
dists[tid] = dists[tid + s];
dists_idx[tid] = dists_idx[tid + s];
}
}
__syncthreads();
}
// TODO(nikhilar): As reduction proceeds, the number of “active” threads
// decreases. When tid < 32, there should only be one warp left which could
// be unrolled.
// The overall max after reducing will be saved
// at the location of tid = 0.
selected = dists_idx[0];
if (tid == 0) {
// Write the farthest point for iteration k to global memory
idxs[batch_idx][k] = selected;
}
}
}
at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points, // (N, P, 3)
const at::Tensor& lengths, // (N,)
const at::Tensor& K, // (N,)
const at::Tensor& start_idxs) {
// Check inputs are on the same device
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
at::CheckedFrom c = "FarthestPointSamplingCuda";
at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t});
at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});
// Set the device for the kernel launch based on the device of points
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(
points.size(0) == lengths.size(0),
"Point and lengths must have the same batch dimension");
TORCH_CHECK(
points.size(0) == K.size(0),
"Points and K must have the same batch dimension");
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t max_K = at::max(K).item<int64_t>();
// Initialize the output tensor with the sampled indices
auto idxs = at::full({N, max_K}, -1, lengths.options());
auto min_point_dist = at::full({N, P}, 1e10, points.options());
if (N == 0 || P == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return idxs;
}
// Set the number of blocks to the batch size so that the
// block reduction step can be done for each pointcloud
// to find the max distance point in the pointcloud at each iteration.
const size_t blocks = N;
// Set the threads to the nearest power of 2 of the number of
// points in the pointcloud (up to the max threads in a block).
// This will ensure each thread processes the minimum necessary number of
// points (P/threads).
const int points_pow_2 = std::log(static_cast<double>(P)) / std::log(2.0);
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1);
// Create the accessors
auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
auto lengths_a =
lengths.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
auto K_a = K.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
auto idxs_a = idxs.packed_accessor64<int64_t, 2, at::RestrictPtrTraits>();
auto start_idxs_a =
start_idxs.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
auto min_point_dist_a =
min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();
// Initialize the shared memory which will be used to store the
// distance/index of the best point seen by each thread.
size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t);
// TODO: using shared memory for min_point_dist gives an ~2x speed up
// compared to using a global (N, P) shaped tensor, however for
// larger pointclouds this may exceed the shared memory limit per block.
// If a speed up is required for smaller pointclouds, then the storage
// could be switched to shared memory if the required total shared memory is
// within the memory limit per block.
// Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
// block.
switch (threads) {
case 1024:
FarthestPointSamplingKernel<1024>
<<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 512:
FarthestPointSamplingKernel<512><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 256:
FarthestPointSamplingKernel<256><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 128:
FarthestPointSamplingKernel<128><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 64:
FarthestPointSamplingKernel<64><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 32:
FarthestPointSamplingKernel<32><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 16:
FarthestPointSamplingKernel<16><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 8:
FarthestPointSamplingKernel<8><<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 4:
FarthestPointSamplingKernel<4><<<threads, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 2:
FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 1:
FarthestPointSamplingKernel<1><<<threads, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
default:
FarthestPointSamplingKernel<1024>
<<<blocks, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
}
AT_CUDA_CHECK(cudaGetLastError());
return idxs;
}

View File

@ -29,28 +29,44 @@
// K: a tensor of length (N,) giving the number of
// samples to select for each element in the batch.
// The number of samples is typically << P.
// random_start_point: bool, if True, a random point is selected as the
// starting point for iterative sampling.
// start_idxs: (N,) long Tensor giving the index of the first point to
// sample. Default is all 0. When a random start point is required,
// start_idxs should be set to a random value between [0, lengths[n]]
// for batch element n.
// Returns:
// selected_indices: (N, K) array of selected indices. If the values in
// K are not all the same, then the shape will be (N, max(K), D), and
// padded with -1 for batch elements where k_i < max(K). The selected
// points are gathered in the pytorch autograd wrapper.
at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const at::Tensor& start_idxs);
at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point);
const at::Tensor& start_idxs);
// Exposed implementation.
at::Tensor FarthestPointSampling(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const bool random_start_point) {
const at::Tensor& start_idxs) {
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
AT_ERROR("CUDA implementation not yet supported");
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(lengths);
CHECK_CUDA(K);
CHECK_CUDA(start_idxs);
return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
}

View File

@ -15,3 +15,6 @@
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// Max possible threads per block
const int MAX_THREADS_PER_BLOCK = 1024;

View File

@ -10,41 +10,85 @@
#include <math.h>
#include <cstdio>
// helper WarpReduce used in .cu files
// Helper functions WarpReduceMin and WarpReduceMax used in .cu files
// Starting in Volta, instructions are no longer synchronous within a warp.
// We need to call __syncwarp() to sync the 32 threads in the warp
// instead of all the threads in the block.
template <typename scalar_t>
__device__ void WarpReduce(
volatile scalar_t* min_dists,
volatile int64_t* min_idxs,
const size_t tid) {
__device__ void
WarpReduceMin(scalar_t* min_dists, int64_t* min_idxs, const size_t tid) {
// s = 32
if (min_dists[tid] > min_dists[tid + 32]) {
min_idxs[tid] = min_idxs[tid + 32];
min_dists[tid] = min_dists[tid + 32];
}
__syncwarp();
// s = 16
if (min_dists[tid] > min_dists[tid + 16]) {
min_idxs[tid] = min_idxs[tid + 16];
min_dists[tid] = min_dists[tid + 16];
}
__syncwarp();
// s = 8
if (min_dists[tid] > min_dists[tid + 8]) {
min_idxs[tid] = min_idxs[tid + 8];
min_dists[tid] = min_dists[tid + 8];
}
__syncwarp();
// s = 4
if (min_dists[tid] > min_dists[tid + 4]) {
min_idxs[tid] = min_idxs[tid + 4];
min_dists[tid] = min_dists[tid + 4];
}
__syncwarp();
// s = 2
if (min_dists[tid] > min_dists[tid + 2]) {
min_idxs[tid] = min_idxs[tid + 2];
min_dists[tid] = min_dists[tid + 2];
}
__syncwarp();
// s = 1
if (min_dists[tid] > min_dists[tid + 1]) {
min_idxs[tid] = min_idxs[tid + 1];
min_dists[tid] = min_dists[tid + 1];
}
__syncwarp();
}
template <typename scalar_t>
__device__ void WarpReduceMax(
volatile scalar_t* dists,
volatile int64_t* dists_idx,
const size_t tid) {
if (dists[tid] < dists[tid + 32]) {
dists[tid] = dists[tid + 32];
dists_idx[tid] = dists_idx[tid + 32];
}
__syncwarp();
if (dists[tid] < dists[tid + 16]) {
dists[tid] = dists[tid + 16];
dists_idx[tid] = dists_idx[tid + 16];
}
__syncwarp();
if (dists[tid] < dists[tid + 8]) {
dists[tid] = dists[tid + 8];
dists_idx[tid] = dists_idx[tid + 8];
}
__syncwarp();
if (dists[tid] < dists[tid + 4]) {
dists[tid] = dists[tid + 4];
dists_idx[tid] = dists_idx[tid + 4];
}
__syncwarp();
if (dists[tid] < dists[tid + 2]) {
dists[tid] = dists[tid + 2];
dists_idx[tid] = dists_idx[tid + 2];
}
__syncwarp();
if (dists[tid] < dists[tid + 1]) {
dists[tid] = dists[tid + 1];
dists_idx[tid] = dists_idx[tid + 1];
}
__syncwarp();
}

View File

@ -24,6 +24,7 @@ from .points_to_volumes import (
add_pointclouds_to_volumes,
add_points_features_to_volume_densities_features,
)
from .sample_farthest_points import sample_farthest_points
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .utils import (

View File

@ -57,7 +57,7 @@ def sample_farthest_points(
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
if lengths.shape[0] != N:
if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
# TODO: support providing K as a ratio of the total number of points instead of as an int
@ -77,9 +77,15 @@ def sample_farthest_points(
if not (K.dtype == torch.int64):
K = K.to(torch.int64)
# Generate the starting indices for sampling
start_idxs = torch.zeros_like(lengths)
if random_start_point:
for n in range(N):
start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
with torch.no_grad():
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
sampled_points = masked_gather(points, idx)
return sampled_points, idx

View File

@ -29,8 +29,17 @@ def bm_fps() -> None:
warmup_iters=1,
)
kwargs_list = [k for k in kwargs_list if k["device"] == "cpu"]
benchmark(TestFPS.sample_farthest_points, "FPS_CPU", kwargs_list, warmup_iters=1)
# Add some larger batch sizes and pointcloud sizes
Ns = [32]
Ps = [2048, 8192, 18384]
Ds = [3, 9]
Ks = [24, 48]
test_cases = product(Ns, Ps, Ds, Ks, backends)
for case in test_cases:
N, P, D, K, d = case
kwargs_list.append({"N": N, "P": P, "D": D, "K": K, "device": d})
benchmark(TestFPS.sample_farthest_points, "FPS", kwargs_list, warmup_iters=1)
if __name__ == "__main__":

View File

@ -6,14 +6,25 @@
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from common_testing import (
TestCaseMixin,
get_random_cuda_device,
get_tests_dir,
get_pytorch3d_dir,
)
from pytorch3d.io import load_obj
from pytorch3d.ops.sample_farthest_points import (
sample_farthest_points_naive,
sample_farthest_points,
)
from pytorch3d.ops.utils import masked_gather
DATA_DIR = get_tests_dir() / "data"
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
DEBUG = False
class TestFPS(TestCaseMixin, unittest.TestCase):
def _test_simple(self, fps_func, device="cpu"):
@ -123,22 +134,22 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
def _test_random_start(self, fps_func, device="cpu"):
N, P, D, K = 5, 40, 5, 8
points = torch.randn((N, P, D), device=device)
out_points, out_idxs = sample_farthest_points_naive(
points, K=K, random_start_point=True
)
# Check the first index is not 0 for all batch elements
points = torch.randn((N, P, D), dtype=torch.float32, device=device)
out_points, out_idxs = fps_func(points, K=K, random_start_point=True)
# Check the first index is not 0 or the same number for all batch elements
# when random_start_point = True
self.assertTrue(out_idxs[:, 0].sum() > 0)
self.assertFalse(out_idxs[:, 0].eq(out_idxs[0, 0]).all())
def _test_gradcheck(self, fps_func, device="cpu"):
N, P, D, K = 2, 5, 3, 2
N, P, D, K = 2, 10, 3, 2
points = torch.randn(
(N, P, D), dtype=torch.float32, device=device, requires_grad=True
)
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
torch.autograd.gradcheck(
fps_func,
(points, None, K),
(points, lengths, K),
check_undefined_grad=False,
eps=2e-3,
atol=0.001,
@ -158,6 +169,76 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
self._test_random_start(sample_farthest_points, "cpu")
self._test_gradcheck(sample_farthest_points, "cpu")
def test_sample_farthest_points_cuda(self):
device = get_random_cuda_device()
self._test_simple(sample_farthest_points, device)
self._test_errors(sample_farthest_points, device)
self._test_compare_random_heterogeneous(device)
self._test_random_start(sample_farthest_points, device)
self._test_gradcheck(sample_farthest_points, device)
def test_cuda_vs_cpu(self):
"""
Compare cuda vs cpu on a complex object
"""
obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
K = 250
# Run on CPU
device = "cpu"
points, _, _ = load_obj(obj_filename, device=device, load_textures=False)
points = points[None, ...]
out_points_cpu, out_idxs_cpu = sample_farthest_points(points, K=K)
# Run on GPU
device = get_random_cuda_device()
points_cuda = points.to(device)
out_points_cuda, out_idxs_cuda = sample_farthest_points(points_cuda, K=K)
# Check that the indices from CUDA and CPU match
self.assertClose(out_idxs_cpu, out_idxs_cuda.cpu())
# Check there are no duplicate indices
val_mask = out_idxs_cuda[0].ne(-1)
vals, counts = torch.unique(out_idxs_cuda[0][val_mask], return_counts=True)
self.assertTrue(counts.le(1).all())
# Plot all results
if DEBUG:
# mplot3d is required for 3d projection plots
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d # noqa: F401
# Move to cpu and convert to numpy for plotting
points = points.squeeze()
out_points_cpu = out_points_cpu.squeeze().numpy()
out_points_cuda = out_points_cuda.squeeze().cpu().numpy()
# Farthest point sampling CPU
fig = plt.figure(figsize=plt.figaspect(1.0 / 3))
ax1 = fig.add_subplot(1, 3, 1, projection="3d")
ax1.scatter(*points.T, alpha=0.1)
ax1.scatter(*out_points_cpu.T, color="black")
ax1.set_title("FPS CPU")
# Farthest point sampling CUDA
ax2 = fig.add_subplot(1, 3, 2, projection="3d")
ax2.scatter(*points.T, alpha=0.1)
ax2.scatter(*out_points_cuda.T, color="red")
ax2.set_title("FPS CUDA")
# Random Sampling
random_points = np.random.permutation(points)[:K]
ax3 = fig.add_subplot(1, 3, 3, projection="3d")
ax3.scatter(*points.T, alpha=0.1)
ax3.scatter(*random_points.T, color="green")
ax3.set_title("Random")
# Save image
filename = "DEBUG_fps.jpg"
filepath = DATA_DIR / filename
plt.savefig(filepath)
@staticmethod
def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
device = torch.device(device)