From 5043d15361d16a7093b4b60572c5f730c6c83308 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Wed, 23 Jul 2025 10:23:40 -0700 Subject: [PATCH] avoid CPU/GPU sync in sample_farthest_points Summary: Optimizing sample_farthest_poinst by reducing CPU/GPU sync: 1. replacing iterative randint for starting indexes for 1 function call, if length is constant 2. Avoid sync in fetching maxumum of sample points, if we sample the same amount 3. Initializing 1 tensor for samples and indixes compare https://fburl.com/mlhub/7wk0xi98 Before {F1980383703} after {F1980383707} Histogram match pretty closely {F1980464338} Reviewed By: bottler Differential Revision: D78731869 fbshipit-source-id: 060528ae7a1e0fbbd005d129c151eaf9405841de --- .../sample_farthest_points.cu | 10 ++++++++-- .../sample_farthest_points.h | 9 ++++++--- pytorch3d/ops/sample_farthest_points.py | 15 ++++++++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu index a0b84dd4..5b788629 100644 --- a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu +++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu @@ -107,7 +107,8 @@ at::Tensor FarthestPointSamplingCuda( const at::Tensor& points, // (N, P, 3) const at::Tensor& lengths, // (N,) const at::Tensor& K, // (N,) - const at::Tensor& start_idxs) { + const at::Tensor& start_idxs, + const int64_t max_K_known = -1) { // Check inputs are on the same device at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2}, k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4}; @@ -129,7 +130,12 @@ at::Tensor FarthestPointSamplingCuda( const int64_t N = points.size(0); const int64_t P = points.size(1); - const int64_t max_K = at::max(K).item(); + int64_t max_K; + if (max_K_known > 0) { + max_K = max_K_known; + } else { + max_K = at::max(K).item(); + } // Initialize the output tensor with the sampled indices auto idxs = at::full({N, max_K}, -1, lengths.options()); diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h index a44a0a81..0db40758 100644 --- a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h +++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h @@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda( const at::Tensor& points, const at::Tensor& lengths, const at::Tensor& K, - const at::Tensor& start_idxs); + const at::Tensor& start_idxs, + const int64_t max_K_known = -1); at::Tensor FarthestPointSamplingCpu( const at::Tensor& points, @@ -56,14 +57,16 @@ at::Tensor FarthestPointSampling( const at::Tensor& points, const at::Tensor& lengths, const at::Tensor& K, - const at::Tensor& start_idxs) { + const at::Tensor& start_idxs, + const int64_t max_K_known = -1) { if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) { #ifdef WITH_CUDA CHECK_CUDA(points); CHECK_CUDA(lengths); CHECK_CUDA(K); CHECK_CUDA(start_idxs); - return FarthestPointSamplingCuda(points, lengths, K, start_idxs); + return FarthestPointSamplingCuda( + points, lengths, K, start_idxs, max_K_known); #else AT_ERROR("Not compiled with GPU support."); #endif diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py index a45b1de2..999aa0e4 100644 --- a/pytorch3d/ops/sample_farthest_points.py +++ b/pytorch3d/ops/sample_farthest_points.py @@ -55,6 +55,7 @@ def sample_farthest_points( N, P, D = points.shape device = points.device + constant_length = lengths is None # Validate inputs if lengths is None: lengths = torch.full((N,), P, dtype=torch.int64, device=device) @@ -65,7 +66,9 @@ def sample_farthest_points( raise ValueError("A value in lengths was too large.") # TODO: support providing K as a ratio of the total number of points instead of as an int + max_K = -1 if isinstance(K, int): + max_K = K K = torch.full((N,), K, dtype=torch.int64, device=device) elif isinstance(K, list): K = torch.tensor(K, dtype=torch.int64, device=device) @@ -82,15 +85,17 @@ def sample_farthest_points( K = K.to(torch.int64) # Generate the starting indices for sampling - start_idxs = torch.zeros_like(lengths) if random_start_point: - for n in range(N): - # pyre-fixme[6]: For 1st param expected `int` but got `Tensor`. - start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item() + if constant_length: + start_idxs = torch.randint(high=P, size=(N,), device=device) + else: + start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64) + else: + start_idxs = torch.zeros_like(lengths) with torch.no_grad(): # pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`. - idx = _C.sample_farthest_points(points, lengths, K, start_idxs) + idx = _C.sample_farthest_points(points, lengths, K, start_idxs, max_K) sampled_points = masked_gather(points, idx) return sampled_points, idx