mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
Farthest point sampling CUDA
Summary:
CUDA implementation of farthest point sampling algorithm.
## Visual comparison
Compared to random sampling, farthest point sampling gives better coverage of the shape.
{F658631262}
## Reduction
Parallelized block reduction to find the max value at each iteration happens as follows:
1. First split the points into two equal sized parts (e.g. for a list with 8 values):
`[20, 27, 6, 8 | 11, 10, 2, 33]`
2. Use half of the thread (4 threads) to compare pairs of elements from each half (e.g elements [0, 4], [1, 5] etc) and store the result in the first half of the list:
`[20, 27, 6, 33 | 11, 10, 2, 33]`
Now we no longer care about the second part but again divide the first part into two
`[20, 27 | 6, 33| -, -, -, -]`
Now we can use 2 threads to compare the 4 elements
4. Finally we have gotten down to a single pair
`[20 | 33 | -, - | -, -, -, -]`
Use 1 thread to compare the remaining two elements
5. The max will now be at thread id = 0
`[33 | - | -, - | -, -, -, -]`
The reduction will give the farthest point for the selected batch index at this iteration.
Reviewed By: bottler, jcjohnson
Differential Revision: D30401803
fbshipit-source-id: 525bd5ae27c4b13b501812cfe62306bb003827d2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d9f7611c4b
commit
bd04ffaf77
@@ -24,6 +24,7 @@ from .points_to_volumes import (
|
||||
add_pointclouds_to_volumes,
|
||||
add_points_features_to_volume_densities_features,
|
||||
)
|
||||
from .sample_farthest_points import sample_farthest_points
|
||||
from .sample_points_from_meshes import sample_points_from_meshes
|
||||
from .subdivide_meshes import SubdivideMeshes
|
||||
from .utils import (
|
||||
|
||||
@@ -57,7 +57,7 @@ def sample_farthest_points(
|
||||
if lengths is None:
|
||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||
|
||||
if lengths.shape[0] != N:
|
||||
if lengths.shape != (N,):
|
||||
raise ValueError("points and lengths must have same batch dimension.")
|
||||
|
||||
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
||||
@@ -77,9 +77,15 @@ def sample_farthest_points(
|
||||
if not (K.dtype == torch.int64):
|
||||
K = K.to(torch.int64)
|
||||
|
||||
# Generate the starting indices for sampling
|
||||
start_idxs = torch.zeros_like(lengths)
|
||||
if random_start_point:
|
||||
for n in range(N):
|
||||
start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
|
||||
|
||||
with torch.no_grad():
|
||||
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
|
||||
idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
|
||||
idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
|
||||
sampled_points = masked_gather(points, idx)
|
||||
|
||||
return sampled_points, idx
|
||||
|
||||
Reference in New Issue
Block a user