Farthest point sampling CUDA

Summary:
CUDA implementation of farthest point sampling algorithm.

## Visual comparison

Compared to random sampling, farthest point sampling gives better coverage of the shape.

{F658631262}

## Reduction

Parallelized block reduction to find the max value at each iteration happens as follows:

1. First split the points into two equal sized parts (e.g. for a list with 8 values):
`[20, 27, 6, 8 | 11, 10, 2, 33]`
2. Use half of the thread (4 threads) to compare pairs of elements from each half (e.g elements [0, 4], [1, 5] etc) and store the result in the first half of the list:
`[20, 27, 6, 33 | 11, 10, 2, 33]`
Now we no longer care about the second part but again divide the first part into two
`[20, 27 | 6, 33| -, -, -, -]`
Now we can use 2 threads to compare the 4 elements
4. Finally we have gotten down to a single pair
`[20 | 33 | -, - | -, -, -, -]`
Use 1 thread to compare the remaining two elements
5. The max will now be at thread id = 0
`[33 | - | -, - | -, -, -, -]`
The reduction will give the farthest point for the selected batch index at this iteration.

Reviewed By: bottler, jcjohnson

Differential Revision: D30401803

fbshipit-source-id: 525bd5ae27c4b13b501812cfe62306bb003827d2
This commit is contained in:
Nikhila Ravi
2021-09-15 13:47:55 -07:00
committed by Facebook GitHub Bot
parent d9f7611c4b
commit bd04ffaf77
11 changed files with 441 additions and 33 deletions

View File

@@ -24,6 +24,7 @@ from .points_to_volumes import (
add_pointclouds_to_volumes,
add_points_features_to_volume_densities_features,
)
from .sample_farthest_points import sample_farthest_points
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .utils import (

View File

@@ -57,7 +57,7 @@ def sample_farthest_points(
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
if lengths.shape[0] != N:
if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
# TODO: support providing K as a ratio of the total number of points instead of as an int
@@ -77,9 +77,15 @@ def sample_farthest_points(
if not (K.dtype == torch.int64):
K = K.to(torch.int64)
# Generate the starting indices for sampling
start_idxs = torch.zeros_like(lengths)
if random_start_point:
for n in range(N):
start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
with torch.no_grad():
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
sampled_points = masked_gather(points, idx)
return sampled_points, idx