Avoid CPU/GPU sync in sample_farthest_points

Summary:
Optimize sample_farthest_points by reducing CPU/GPU synchronization:
1. Replace the per-element randint loop for starting indices with a single batched call when all lengths are equal (see the sketch below)
2. Avoid the sync from fetching the maximum number of sample points on the host when every element samples the same amount
3. Initialize a single tensor for the samples and indices
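
Each of these removes a point where the host has to block on the device. For item 1, a minimal standalone sketch of the pattern (illustrative only, not the library code; N, P, and lengths are made-up values):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    N, P = 32, 4096  # batch size and points per cloud
    lengths = torch.full((N,), P, dtype=torch.int64, device=device)

    # Before: one randint per batch element. Reading lengths[n] on the host
    # and calling .item() each copy a scalar off the device, so the CPU
    # stalls on the GPU N times per call.
    start_idxs = torch.zeros_like(lengths)
    for n in range(N):
        start_idxs[n] = torch.randint(high=int(lengths[n]), size=(1,)).item()

    # After: when every cloud has the same length P, one batched randint
    # produces all N starting indices directly on the device, sync-free.
    start_idxs = torch.randint(high=P, size=(N,), device=device)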

Comparison:
https://fburl.com/mlhub/7wk0xi98
Before
{F1980383703}
After
{F1980383707}

The histograms match closely:
{F1980464338}

Reviewed By: bottler

Differential Revision: D78731869

fbshipit-source-id: 060528ae7a1e0fbbd005d129c151eaf9405841de


@@ -107,7 +107,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points, // (N, P, 3)
     const at::Tensor& lengths, // (N,)
     const at::Tensor& K, // (N,)
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1) {
   // Check inputs are on the same device
   at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
       k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
@@ -129,7 +130,12 @@ at::Tensor FarthestPointSamplingCuda(
   const int64_t N = points.size(0);
   const int64_t P = points.size(1);
-  const int64_t max_K = at::max(K).item<int64_t>();
+  int64_t max_K;
+  if (max_K_known > 0) {
+    max_K = max_K_known;
+  } else {
+    max_K = at::max(K).item<int64_t>();
+  }
   // Initialize the output tensor with the sampled indices
   auto idxs = at::full({N, max_K}, -1, lengths.options());
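
This hunk is the C++ side of item 2: at::max(K).item<int64_t>() runs a cheap reduction on the device, but item() then blocks the host until the scalar is copied back. A minimal Python sketch of the same pattern (illustrative values):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    K = torch.full((32,), 64, dtype=torch.int64, device=device)

    # Equivalent of at::max(K).item<int64_t>(): the reduction is async,
    # but .item() stalls the CPU until the result reaches the host.
    max_K = K.max().item()

    # When the caller built K from a plain Python int (see the Python
    # change below), the maximum is already known host-side and can be
    # forwarded as max_K_known, skipping the round-trip entirely.
    max_K_known = 64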


@@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs);
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1);

 at::Tensor FarthestPointSamplingCpu(
     const at::Tensor& points,
@@ -56,14 +57,16 @@ at::Tensor FarthestPointSampling(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1) {
   if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(points);
     CHECK_CUDA(lengths);
     CHECK_CUDA(K);
     CHECK_CUDA(start_idxs);
-    return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
+    return FarthestPointSamplingCuda(
+        points, lengths, K, start_idxs, max_K_known);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif


@@ -55,6 +55,7 @@ def sample_farthest_points(
     N, P, D = points.shape
     device = points.device
+    constant_length = lengths is None

     # Validate inputs
     if lengths is None:
         lengths = torch.full((N,), P, dtype=torch.int64, device=device)
@@ -65,7 +66,9 @@ def sample_farthest_points(
             raise ValueError("A value in lengths was too large.")

     # TODO: support providing K as a ratio of the total number of points instead of as an int
+    max_K = -1
     if isinstance(K, int):
+        max_K = K
         K = torch.full((N,), K, dtype=torch.int64, device=device)
     elif isinstance(K, list):
         K = torch.tensor(K, dtype=torch.int64, device=device)
@@ -82,15 +85,17 @@ def sample_farthest_points(
     K = K.to(torch.int64)

     # Generate the starting indices for sampling
-    start_idxs = torch.zeros_like(lengths)
     if random_start_point:
-        for n in range(N):
-            # pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
-            start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
+        if constant_length:
+            start_idxs = torch.randint(high=P, size=(N,), device=device)
+        else:
+            start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64)
+    else:
+        start_idxs = torch.zeros_like(lengths)

     with torch.no_grad():
         # pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
-        idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
+        idx = _C.sample_farthest_points(points, lengths, K, start_idxs, max_K)
     sampled_points = masked_gather(points, idx)

     return sampled_points, idx
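
One way to check that the common path no longer synchronizes implicitly is PyTorch's sync debug mode, which warns whenever an op blocks on the device. A sketch, assuming a CUDA device and a GPU build of PyTorch3D:

    import torch
    from pytorch3d.ops import sample_farthest_points

    points = torch.rand(8, 2048, 3, device="cuda")

    # Warn on every operation that implicitly syncs with the host. Before
    # this change, the per-batch randint loop and the max-of-K fetch would
    # each be reported here; afterwards the call should be sync-free when
    # K is an int and lengths is None.
    torch.cuda.set_sync_debug_mode("warn")
    sampled_points, idx = sample_farthest_points(points, K=64, random_start_point=True)
    torch.cuda.set_sync_debug_mode("default")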