diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp
index 3b9a9edc..2178850c 100644
--- a/pytorch3d/csrc/ext.cpp
+++ b/pytorch3d/csrc/ext.cpp
@@ -26,8 +26,8 @@
 #include "point_mesh/point_mesh_cuda.h"
 #include "rasterize_meshes/rasterize_meshes.h"
 #include "rasterize_points/rasterize_points.h"
-#include "sample_pdf/sample_pdf.h"
 #include "sample_farthest_points/sample_farthest_points.h"
+#include "sample_pdf/sample_pdf.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
diff --git a/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu b/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
index b1ebf78c..85716c00 100644
--- a/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
+++ b/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
@@ -121,7 +121,7 @@ __global__ void DistanceForwardKernel(
   // Unroll the last 6 iterations of the loop since they will happen
   // synchronized within a single warp.
   if (tid < 32)
-    WarpReduce(min_dists, min_idxs, tid);
+    WarpReduceMin(min_dists, min_idxs, tid);
 
   // Finally thread 0 writes the result to the output buffer.
   if (tid == 0) {
diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
index f4167ac5..aaa860ce 100644
--- a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
+++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp
@@ -15,7 +15,7 @@ at::Tensor FarthestPointSamplingCpu(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const bool random_start_point) {
+    const at::Tensor& start_idxs) {
   // Get constants
   const int64_t N = points.size(0);
   const int64_t P = points.size(1);
@@ -32,6 +32,7 @@ at::Tensor FarthestPointSamplingCpu(
   auto lengths_a = lengths.accessor<int64_t, 1>();
   auto k_a = K.accessor<int64_t, 1>();
   auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
+  auto start_idxs_a = start_idxs.accessor<int64_t, 1>();
 
   // Initialize a mask to prevent duplicates
   // If true, the point has already been selected.
@@ -41,10 +42,6 @@ at::Tensor FarthestPointSamplingCpu(
   // distances from each point to any of the previously selected points
   std::vector<float> dists(P, std::numeric_limits<float>::max());
 
-  // Initialize random number generation for random starting points
-  std::random_device rd;
-  std::default_random_engine eng(rd());
-
   for (int64_t n = 0; n < N; ++n) {
     // Resize and reset points mask and distances for each batch
     selected_points_mask.resize(lengths_a[n]);
@@ -52,9 +49,8 @@ at::Tensor FarthestPointSamplingCpu(
     std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
     std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());
 
-    // Select a starting point index and save it
-    std::uniform_int_distribution<int64_t> distr(0, lengths_a[n] - 1);
-    int64_t last_idx = random_start_point ? distr(eng) : 0;
+    // Get the starting point index and save it
+    int64_t last_idx = start_idxs_a[n];
     sampled_indices_a[n][0] = last_idx;
 
     // Set the value of the mask at this point to false
diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu
new file mode 100644
index 00000000..5437cccb
--- /dev/null
+++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu
@@ -0,0 +1,252 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
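+// Overview: iterative farthest point sampling. At each step, select the
+// point that is farthest (in squared euclidean distance) from the set of
+// points selected so far. One thread block processes one batch element;
+// the threads in a block cooperate via a shared memory max-reduction to
+// find the farthest remaining point at each iteration.
+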
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include "utils/pytorch3d_cutils.h"
+#include "utils/warp_reduce.cuh"
+
+template <unsigned int block_size>
+__global__ void FarthestPointSamplingKernel(
+    // clang-format off
+    const at::PackedTensorAccessor64<float, 3> points,
+    const at::PackedTensorAccessor64<int64_t, 1> lengths,
+    const at::PackedTensorAccessor64<int64_t, 1> K,
+    at::PackedTensorAccessor64<int64_t, 2> idxs,
+    at::PackedTensorAccessor64<float, 2> min_point_dist,
+    const at::PackedTensorAccessor64<int64_t, 1> start_idxs
+    // clang-format on
+) {
+  // Get constants
+  const int64_t N = points.size(0);
+  const int64_t P = points.size(1);
+  const int64_t D = points.size(2);
+
+  // Create single shared memory buffer which is split and cast to different
+  // types: dists/dists_idx are used to save the maximum distances seen by the
+  // points processed by any one thread and the associated point indices.
+  // These values only need to be accessed by other threads in this block which
+  // are processing the same batch and not by other blocks.
+  extern __shared__ char shared_buf[];
+  float* dists = (float*)shared_buf; // block_size floats
+  int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t
+
+  // Get batch index and thread index
+  const int64_t batch_idx = blockIdx.x;
+  const size_t tid = threadIdx.x;
+
+  // If K is greater than the number of points in the pointcloud
+  // we only need to iterate until the smaller value is reached.
+  const int64_t k_n = min(K[batch_idx], lengths[batch_idx]);
+
+  // Write the first selected point to global memory in the first thread
+  int64_t selected = start_idxs[batch_idx];
+  if (tid == 0)
+    idxs[batch_idx][0] = selected;
+
+  // Iterate to find k_n sampled points
+  for (int64_t k = 1; k < k_n; ++k) {
+    // Keep track of the maximum of the minimum distance to previously selected
+    // points seen by this thread
+    int64_t max_dist_idx = 0;
+    float max_dist = -1.0;
+
+    // Iterate through all the points in this pointcloud. For already selected
+    // points, the minimum distance to the set of previously selected points
+    // will be 0.0 so they won't be selected again.
+    for (int64_t p = tid; p < lengths[batch_idx]; p += block_size) {
+      // Calculate the distance to the last selected point
+      float dist2 = 0.0;
+      for (int64_t d = 0; d < D; ++d) {
+        float diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
+        dist2 += (diff * diff);
+      }
+
+      // If the distance of point p to the last selected point is
+      // less than the previous minimum distance of p to the set of selected
+      // points, then update the corresponding value in min_point_dist
+      // so it always contains the min distance.
+      const float p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
+      min_point_dist[batch_idx][p] = p_min_dist;
+
+      // Update the max distance and point idx for this thread.
+      max_dist_idx = (p_min_dist > max_dist) ? p : max_dist_idx;
+      max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
+    }
+
+    // After going through all points for this thread, save the max
+    // point and idx seen by this thread. Each thread sees P/block_size points.
+    dists[tid] = max_dist;
+    dists_idx[tid] = max_dist_idx;
+    // Sync to ensure all threads in the block have updated their max point.
+    __syncthreads();
+
+    // Parallelized block reduction to find the max point seen by
+    // all the threads in this block for iteration k.
+    // Each block represents one batch element so we can use a divide/conquer
+    // approach to find the max, syncing all threads after each step.
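+    // For example, with block_size = 8 the loop below runs s = 4, 2, 1:
+    //   s = 4: threads 0-3 fold dists[4..7] into dists[0..3]
+    //   s = 2: threads 0-1 fold dists[2..3] into dists[0..1]
+    //   s = 1: thread 0 folds dists[1] into dists[0]
+    // leaving the block-wide max distance and its index at tid = 0.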
+
+    for (int s = block_size / 2; s > 0; s >>= 1) {
+      if (tid < s) {
+        // Compare the best point seen by two threads and update the shared
+        // memory at the location of the first thread index with the max out
+        // of the two threads.
+        if (dists[tid] < dists[tid + s]) {
+          dists[tid] = dists[tid + s];
+          dists_idx[tid] = dists_idx[tid + s];
+        }
+      }
+      __syncthreads();
+    }
+
+    // TODO(nikhilar): As reduction proceeds, the number of "active" threads
+    // decreases. When tid < 32, there should only be one warp left which could
+    // be unrolled.
+
+    // The overall max after reducing will be saved
+    // at the location of tid = 0.
+    selected = dists_idx[0];
+
+    if (tid == 0) {
+      // Write the farthest point for iteration k to global memory
+      idxs[batch_idx][k] = selected;
+    }
+  }
+}
+
+at::Tensor FarthestPointSamplingCuda(
+    const at::Tensor& points, // (N, P, 3)
+    const at::Tensor& lengths, // (N,)
+    const at::Tensor& K, // (N,)
+    const at::Tensor& start_idxs) {
+  // Check inputs are on the same device
+  at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
+      k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
+  at::CheckedFrom c = "FarthestPointSamplingCuda";
+  at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t});
+  at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});
+
+  // Set the device for the kernel launch based on the device of points
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  TORCH_CHECK(
+      points.size(0) == lengths.size(0),
+      "Points and lengths must have the same batch dimension");
+
+  TORCH_CHECK(
+      points.size(0) == K.size(0),
+      "Points and K must have the same batch dimension");
+
+  const int64_t N = points.size(0);
+  const int64_t P = points.size(1);
+  const int64_t max_K = at::max(K).item<int64_t>();
+
+  // Initialize the output tensor with the sampled indices
+  auto idxs = at::full({N, max_K}, -1, lengths.options());
+  auto min_point_dist = at::full({N, P}, 1e10, points.options());
+
+  if (N == 0 || P == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return idxs;
+  }
+
+  // Set the number of blocks to the batch size so that the
+  // block reduction step can be done for each pointcloud
+  // to find the max distance point in the pointcloud at each iteration.
+  const size_t blocks = N;
+
+  // Set the threads to the nearest power of 2 of the number of
+  // points in the pointcloud (up to the max threads in a block).
+  // This will ensure each thread processes the minimum necessary number of
+  // points (P/threads).
+  const int points_pow_2 = std::log(static_cast<double>(P)) / std::log(2.0);
+  const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1);
+
+  // Create the accessors
+  auto points_a = points.packed_accessor64<float, 3>();
+  auto lengths_a = lengths.packed_accessor64<int64_t, 1>();
+  auto K_a = K.packed_accessor64<int64_t, 1>();
+  auto idxs_a = idxs.packed_accessor64<int64_t, 2>();
+  auto start_idxs_a = start_idxs.packed_accessor64<int64_t, 1>();
+  auto min_point_dist_a = min_point_dist.packed_accessor64<float, 2>();
+
+  // Initialize the shared memory which will be used to store the
+  // distance/index of the best point seen by each thread.
+  size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t);
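+  // For example, with 1024 threads this is 1024 * (4 + 8) = 12 KB, well
+  // within the 48 KB of shared memory available per block by default.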
+  // TODO: using shared memory for min_point_dist gives an ~2x speed up
+  // compared to using a global (N, P) shaped tensor, however for
+  // larger pointclouds this may exceed the shared memory limit per block.
+  // If a speed up is required for smaller pointclouds, then the storage
+  // could be switched to shared memory if the required total shared memory is
+  // within the memory limit per block.
+
+  // Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible
+  // per block.
+  switch (threads) {
+    case 1024:
+      FarthestPointSamplingKernel<1024>
+          <<<blocks, threads, shared_mem, stream>>>(
+              points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 512:
+      FarthestPointSamplingKernel<512><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 256:
+      FarthestPointSamplingKernel<256><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 128:
+      FarthestPointSamplingKernel<128><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 64:
+      FarthestPointSamplingKernel<64><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 32:
+      FarthestPointSamplingKernel<32><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 16:
+      FarthestPointSamplingKernel<16><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 8:
+      FarthestPointSamplingKernel<8><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 4:
+      FarthestPointSamplingKernel<4><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 2:
+      FarthestPointSamplingKernel<2><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    case 1:
+      FarthestPointSamplingKernel<1><<<blocks, threads, shared_mem, stream>>>(
+          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      break;
+    default:
+      FarthestPointSamplingKernel<1024>
+          <<<blocks, threads, shared_mem, stream>>>(
+              points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+  }
+
+  AT_CUDA_CHECK(cudaGetLastError());
+  return idxs;
+}
diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
index bb4456d3..87c7faf5 100644
--- a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
+++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
@@ -29,28 +29,44 @@
 // K: a tensor of length (N,) giving the number of
 //    samples to select for each element in the batch.
 //    The number of samples is typically << P.
-// random_start_point: bool, if True, a random point is selected as the
-//    starting point for iterative sampling.
+// start_idxs: (N,) long Tensor giving the index of the first point to
+//    sample. Default is all 0. When a random start point is required,
+//    start_idxs should be set to a random value in [0, lengths[n]) for
+//    batch element n.
 // Returns:
 // selected_indices: (N, K) array of selected indices. If the values in
 //    K are not all the same, then the shape will be (N, max(K)), padded
 //    with -1 for batch elements where k_i < max(K). The selected
 //    points are gathered in the pytorch autograd wrapper.
+at::Tensor FarthestPointSamplingCuda(
+    const at::Tensor& points,
+    const at::Tensor& lengths,
+    const at::Tensor& K,
+    const at::Tensor& start_idxs);
+
 at::Tensor FarthestPointSamplingCpu(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const bool random_start_point);
+    const at::Tensor& start_idxs);
 
 // Exposed implementation.
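+// Example (hypothetical) host-side call with all-zero start indices:
+//   auto start_idxs = at::zeros({points.size(0)}, lengths.options());
+//   auto idxs = FarthestPointSampling(points, lengths, K, start_idxs);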
 at::Tensor FarthestPointSampling(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const bool random_start_point) {
+    const at::Tensor& start_idxs) {
   if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
-    AT_ERROR("CUDA implementation not yet supported");
+#ifdef WITH_CUDA
+    CHECK_CUDA(points);
+    CHECK_CUDA(lengths);
+    CHECK_CUDA(K);
+    CHECK_CUDA(start_idxs);
+    return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
   }
-  return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
+  return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
 }
diff --git a/pytorch3d/csrc/utils/pytorch3d_cutils.h b/pytorch3d/csrc/utils/pytorch3d_cutils.h
index 660c9dde..c9e80194 100644
--- a/pytorch3d/csrc/utils/pytorch3d_cutils.h
+++ b/pytorch3d/csrc/utils/pytorch3d_cutils.h
@@ -15,3 +15,6 @@
 #define CHECK_CONTIGUOUS_CUDA(x) \
   CHECK_CUDA(x);                 \
   CHECK_CONTIGUOUS(x)
+
+// Max possible threads per block
+const int MAX_THREADS_PER_BLOCK = 1024;
diff --git a/pytorch3d/csrc/utils/warp_reduce.cuh b/pytorch3d/csrc/utils/warp_reduce.cuh
index 035dbf2a..5f021b75 100644
--- a/pytorch3d/csrc/utils/warp_reduce.cuh
+++ b/pytorch3d/csrc/utils/warp_reduce.cuh
@@ -10,41 +10,85 @@
 #include <cstdint>
 #include <cstdlib>
 
-// helper WarpReduce used in .cu files
+// Helper functions WarpReduceMin and WarpReduceMax used in .cu files
+// Starting in Volta, instructions are no longer synchronous within a warp.
+// We need to call __syncwarp() to sync the 32 threads in the warp
+// instead of all the threads in the block.
 template <typename scalar_t>
-__device__ void WarpReduce(
-    volatile scalar_t* min_dists,
-    volatile int64_t* min_idxs,
-    const size_t tid) {
+__device__ void
+WarpReduceMin(scalar_t* min_dists, int64_t* min_idxs, const size_t tid) {
   // s = 32
   if (min_dists[tid] > min_dists[tid + 32]) {
     min_idxs[tid] = min_idxs[tid + 32];
     min_dists[tid] = min_dists[tid + 32];
   }
+  __syncwarp();
   // s = 16
   if (min_dists[tid] > min_dists[tid + 16]) {
     min_idxs[tid] = min_idxs[tid + 16];
     min_dists[tid] = min_dists[tid + 16];
   }
+  __syncwarp();
   // s = 8
   if (min_dists[tid] > min_dists[tid + 8]) {
     min_idxs[tid] = min_idxs[tid + 8];
     min_dists[tid] = min_dists[tid + 8];
   }
+  __syncwarp();
   // s = 4
   if (min_dists[tid] > min_dists[tid + 4]) {
     min_idxs[tid] = min_idxs[tid + 4];
     min_dists[tid] = min_dists[tid + 4];
   }
+  __syncwarp();
   // s = 2
   if (min_dists[tid] > min_dists[tid + 2]) {
     min_idxs[tid] = min_idxs[tid + 2];
     min_dists[tid] = min_dists[tid + 2];
   }
+  __syncwarp();
   // s = 1
   if (min_dists[tid] > min_dists[tid + 1]) {
     min_idxs[tid] = min_idxs[tid + 1];
     min_dists[tid] = min_dists[tid + 1];
   }
+  __syncwarp();
 }
+
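+// WarpReduceMax mirrors WarpReduceMin with the comparison flipped, keeping
+// the maximum value/index pair instead of the minimum (e.g. for the farthest
+// point reduction above, if its final warp were unrolled).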
+template <typename scalar_t>
+__device__ void WarpReduceMax(
+    volatile scalar_t* dists,
+    volatile int64_t* dists_idx,
+    const size_t tid) {
+  if (dists[tid] < dists[tid + 32]) {
+    dists[tid] = dists[tid + 32];
+    dists_idx[tid] = dists_idx[tid + 32];
+  }
+  __syncwarp();
+  if (dists[tid] < dists[tid + 16]) {
+    dists[tid] = dists[tid + 16];
+    dists_idx[tid] = dists_idx[tid + 16];
+  }
+  __syncwarp();
+  if (dists[tid] < dists[tid + 8]) {
+    dists[tid] = dists[tid + 8];
+    dists_idx[tid] = dists_idx[tid + 8];
+  }
+  __syncwarp();
+  if (dists[tid] < dists[tid + 4]) {
+    dists[tid] = dists[tid + 4];
+    dists_idx[tid] = dists_idx[tid + 4];
+  }
+  __syncwarp();
+  if (dists[tid] < dists[tid + 2]) {
+    dists[tid] = dists[tid + 2];
+    dists_idx[tid] = dists_idx[tid + 2];
+  }
+  __syncwarp();
+  if (dists[tid] < dists[tid + 1]) {
+    dists[tid] = dists[tid + 1];
+    dists_idx[tid] = dists_idx[tid + 1];
+  }
+  __syncwarp();
+}
diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py
index 213b350c..aeba4495 100644
--- a/pytorch3d/ops/__init__.py
+++ b/pytorch3d/ops/__init__.py
@@ -24,6 +24,7 @@ from .points_to_volumes import (
     add_pointclouds_to_volumes,
     add_points_features_to_volume_densities_features,
 )
+from .sample_farthest_points import sample_farthest_points
 from .sample_points_from_meshes import sample_points_from_meshes
 from .subdivide_meshes import SubdivideMeshes
 from .utils import (
diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py
index c75253f6..40412694 100644
--- a/pytorch3d/ops/sample_farthest_points.py
+++ b/pytorch3d/ops/sample_farthest_points.py
@@ -57,7 +57,7 @@ def sample_farthest_points(
     if lengths is None:
         lengths = torch.full((N,), P, dtype=torch.int64, device=device)
 
-    if lengths.shape[0] != N:
+    if lengths.shape != (N,):
         raise ValueError("points and lengths must have same batch dimension.")
 
     # TODO: support providing K as a ratio of the total number of points instead of as an int
@@ -77,9 +77,15 @@ def sample_farthest_points(
     if not (K.dtype == torch.int64):
         K = K.to(torch.int64)
 
+    # Generate the starting indices for sampling
+    start_idxs = torch.zeros_like(lengths)
+    if random_start_point:
+        for n in range(N):
+            start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
+
     with torch.no_grad():
         # pyre-fixme[16]: `pytorch3d._C` has no attribute `sample_farthest_points`.
-        idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
+        idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
     sampled_points = masked_gather(points, idx)
 
     return sampled_points, idx
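
For reference, a minimal usage sketch of the updated Python entry point (the shapes and values here are illustrative, not from this diff):

    import torch
    from pytorch3d.ops import sample_farthest_points

    points = torch.randn(2, 100, 3)    # (N, P, D) batch of pointclouds
    lengths = torch.tensor([100, 60])  # number of valid points per cloud
    sampled, idx = sample_farthest_points(
        points, lengths=lengths, K=10, random_start_point=True
    )
    # sampled: (2, 10, 3) selected points; idx: (2, 10) indices into points

Note that `random_start_point` is now resolved to explicit per-batch `start_idxs` on the Python side, so the C++/CUDA implementations stay deterministic given their inputs.
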
diff --git a/tests/bm_sample_farthest_points.py b/tests/bm_sample_farthest_points.py
index 5108d1d1..b48fe04e 100644
--- a/tests/bm_sample_farthest_points.py
+++ b/tests/bm_sample_farthest_points.py
@@ -29,8 +29,17 @@ def bm_fps() -> None:
         warmup_iters=1,
     )
 
-    kwargs_list = [k for k in kwargs_list if k["device"] == "cpu"]
-    benchmark(TestFPS.sample_farthest_points, "FPS_CPU", kwargs_list, warmup_iters=1)
+    # Add some larger batch sizes and pointcloud sizes
+    Ns = [32]
+    Ps = [2048, 8192, 18384]
+    Ds = [3, 9]
+    Ks = [24, 48]
+    test_cases = product(Ns, Ps, Ds, Ks, backends)
+    for case in test_cases:
+        N, P, D, K, d = case
+        kwargs_list.append({"N": N, "P": P, "D": D, "K": K, "device": d})
+
+    benchmark(TestFPS.sample_farthest_points, "FPS", kwargs_list, warmup_iters=1)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_sample_farthest_points.py b/tests/test_sample_farthest_points.py
index 7b071b18..2cd35df8 100644
--- a/tests/test_sample_farthest_points.py
+++ b/tests/test_sample_farthest_points.py
@@ -6,14 +6,25 @@
 
 import unittest
 
+import numpy as np
 import torch
-from common_testing import TestCaseMixin, get_random_cuda_device
+from common_testing import (
+    TestCaseMixin,
+    get_random_cuda_device,
+    get_tests_dir,
+    get_pytorch3d_dir,
+)
+from pytorch3d.io import load_obj
 from pytorch3d.ops.sample_farthest_points import (
     sample_farthest_points_naive,
     sample_farthest_points,
 )
 from pytorch3d.ops.utils import masked_gather
 
+DATA_DIR = get_tests_dir() / "data"
+TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
+DEBUG = False
+
 
 class TestFPS(TestCaseMixin, unittest.TestCase):
     def _test_simple(self, fps_func, device="cpu"):
@@ -123,22 +134,22 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
 
     def _test_random_start(self, fps_func, device="cpu"):
         N, P, D, K = 5, 40, 5, 8
-        points = torch.randn((N, P, D), device=device)
-
-        out_points, out_idxs = sample_farthest_points_naive(
-            points, K=K, random_start_point=True
-        )
-        # Check the first index is not 0 for all batch elements
+        points = torch.randn((N, P, D), dtype=torch.float32, device=device)
+        out_points, out_idxs = fps_func(points, K=K, random_start_point=True)
+        # Check the first index is not 0 or the same number for all batch elements
+        # when random_start_point = True
         self.assertTrue(out_idxs[:, 0].sum() > 0)
+        self.assertFalse(out_idxs[:, 0].eq(out_idxs[0, 0]).all())
 
     def _test_gradcheck(self, fps_func, device="cpu"):
-        N, P, D, K = 2, 5, 3, 2
+        N, P, D, K = 2, 10, 3, 2
         points = torch.randn(
             (N, P, D), dtype=torch.float32, device=device, requires_grad=True
         )
+        lengths = torch.randint(low=1, high=P, size=(N,), device=device)
         torch.autograd.gradcheck(
             fps_func,
-            (points, None, K),
+            (points, lengths, K),
             check_undefined_grad=False,
             eps=2e-3,
             atol=0.001,
@@ -158,6 +169,76 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
         self._test_random_start(sample_farthest_points, "cpu")
         self._test_gradcheck(sample_farthest_points, "cpu")
 
+    def test_sample_farthest_points_cuda(self):
+        device = get_random_cuda_device()
+        self._test_simple(sample_farthest_points, device)
+        self._test_errors(sample_farthest_points, device)
+        self._test_compare_random_heterogeneous(device)
+        self._test_random_start(sample_farthest_points, device)
+        self._test_gradcheck(sample_farthest_points, device)
+
+    def test_cuda_vs_cpu(self):
+        """
+        Compare cuda vs cpu on a complex object
+        """
+        obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
+        K = 250
+
+        # Run on CPU
+        device = "cpu"
+        points, _, _ = load_obj(obj_filename, device=device, load_textures=False)
+        points = points[None, ...]
+        out_points_cpu, out_idxs_cpu = sample_farthest_points(points, K=K)
+
+        # Run on GPU
+        device = get_random_cuda_device()
+        points_cuda = points.to(device)
+        out_points_cuda, out_idxs_cuda = sample_farthest_points(points_cuda, K=K)
+
+        # Check that the indices from CUDA and CPU match
+        self.assertClose(out_idxs_cpu, out_idxs_cuda.cpu())
+
+        # Check there are no duplicate indices
+        val_mask = out_idxs_cuda[0].ne(-1)
+        vals, counts = torch.unique(out_idxs_cuda[0][val_mask], return_counts=True)
+        self.assertTrue(counts.le(1).all())
+
+        # Plot all results
+        if DEBUG:
+            # mplot3d is required for 3d projection plots
+            import matplotlib.pyplot as plt
+            from mpl_toolkits import mplot3d  # noqa: F401
+
+            # Move to cpu and convert to numpy for plotting
+            points = points.squeeze()
+            out_points_cpu = out_points_cpu.squeeze().numpy()
+            out_points_cuda = out_points_cuda.squeeze().cpu().numpy()
+
+            # Farthest point sampling CPU
+            fig = plt.figure(figsize=plt.figaspect(1.0 / 3))
+            ax1 = fig.add_subplot(1, 3, 1, projection="3d")
+            ax1.scatter(*points.T, alpha=0.1)
+            ax1.scatter(*out_points_cpu.T, color="black")
+            ax1.set_title("FPS CPU")
+
+            # Farthest point sampling CUDA
+            ax2 = fig.add_subplot(1, 3, 2, projection="3d")
+            ax2.scatter(*points.T, alpha=0.1)
+            ax2.scatter(*out_points_cuda.T, color="red")
+            ax2.set_title("FPS CUDA")
+
+            # Random Sampling
+            random_points = np.random.permutation(points)[:K]
+            ax3 = fig.add_subplot(1, 3, 3, projection="3d")
+            ax3.scatter(*points.T, alpha=0.1)
+            ax3.scatter(*random_points.T, color="green")
+            ax3.set_title("Random")
+
+            # Save image
+            filename = "DEBUG_fps.jpg"
+            filepath = DATA_DIR / filename
+            plt.savefig(filepath)
+
     @staticmethod
     def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
         device = torch.device(device)