mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-21 05:26:02 +08:00
avoid CPU/GPU sync in sample_farthest_points
Summary: Optimizing sample_farthest_poinst by reducing CPU/GPU sync: 1. replacing iterative randint for starting indexes for 1 function call, if length is constant 2. Avoid sync in fetching maxumum of sample points, if we sample the same amount 3. Initializing 1 tensor for samples and indixes compare https://fburl.com/mlhub/7wk0xi98 Before {F1980383703} after {F1980383707} Histogram match pretty closely {F1980464338} Reviewed By: bottler Differential Revision: D78731869 fbshipit-source-id: 060528ae7a1e0fbbd005d129c151eaf9405841de
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e3d3a67a89
commit
5043d15361
@@ -55,6 +55,7 @@ def sample_farthest_points(
|
||||
N, P, D = points.shape
|
||||
device = points.device
|
||||
|
||||
constant_length = lengths is None
|
||||
# Validate inputs
|
||||
if lengths is None:
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
@@ -65,7 +66,9 @@ def sample_farthest_points(
|
||||
raise ValueError("A value in lengths was too large.")
|
||||
|
||||
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
||||
max_K = -1
|
||||
if isinstance(K, int):
|
||||
max_K = K
|
||||
K = torch.full((N,), K, dtype=torch.int64, device=device)
|
||||
elif isinstance(K, list):
|
||||
K = torch.tensor(K, dtype=torch.int64, device=device)
|
||||
@@ -82,15 +85,17 @@ def sample_farthest_points(
|
||||
K = K.to(torch.int64)
|
||||
|
||||
# Generate the starting indices for sampling
|
||||
start_idxs = torch.zeros_like(lengths)
|
||||
if random_start_point:
|
||||
for n in range(N):
|
||||
# pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
|
||||
start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
|
||||
if constant_length:
|
||||
start_idxs = torch.randint(high=P, size=(N,), device=device)
|
||||
else:
|
||||
start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64)
|
||||
else:
|
||||
start_idxs = torch.zeros_like(lengths)
|
||||
|
||||
with torch.no_grad():
|
||||
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
|
||||
idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
|
||||
idx = _C.sample_farthest_points(points, lengths, K, start_idxs, max_K)
|
||||
sampled_points = masked_gather(points, idx)
|
||||
|
||||
return sampled_points, idx
|
||||
|
||||
Reference in New Issue
Block a user