Farthest point sampling C++

Summary: C++ implementation of iterative farthest point sampling.

Reviewed By: jcjohnson

Differential Revision: D30349887

fbshipit-source-id: d25990f857752633859fe00283e182858a870269
This commit is contained in:
Nikhila Ravi
2021-09-15 13:47:55 -07:00
committed by Facebook GitHub Bot
parent 3b7d78c7a7
commit d9f7611c4b
6 changed files with 346 additions and 19 deletions

View File

@@ -8,11 +8,12 @@ from random import randint
from typing import Optional, Tuple, Union, List
import torch
from pytorch3d import _C
from .utils import masked_gather
def sample_farthest_points_naive(
def sample_farthest_points(
points: torch.Tensor,
lengths: Optional[torch.Tensor] = None,
K: Union[int, List, torch.Tensor] = 50,
@@ -34,7 +35,7 @@ def sample_farthest_points_naive(
points: (N, P, D) array containing the batch of pointclouds
lengths: (N,) number of points in each pointcloud (to support heterogeneous
batches of pointclouds)
K: samples you want in each sampled point cloud (this is typically << P). If
K: samples required in each sampled point cloud (this is typically << P). If
K is an int then the same number of samples are selected for each
pointcloud in the batch. If K is a tensor is should be length (N,)
giving the number of samples to select for each element in the batch
@@ -52,6 +53,50 @@ def sample_farthest_points_naive(
N, P, D = points.shape
device = points.device
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
if lengths.shape[0] != N:
raise ValueError("points and lengths must have same batch dimension.")
# TODO: support providing K as a ratio of the total number of points instead of as an int
if isinstance(K, int):
K = torch.full((N,), K, dtype=torch.int64, device=device)
elif isinstance(K, list):
K = torch.tensor(K, dtype=torch.int64, device=device)
if K.shape[0] != N:
raise ValueError("K and points must have the same batch dimension")
# Check dtypes are correct and convert if necessary
if not (points.dtype == torch.float32):
points = points.to(torch.float32)
if not (lengths.dtype == torch.int64):
lengths = lengths.to(torch.int64)
if not (K.dtype == torch.int64):
K = K.to(torch.int64)
with torch.no_grad():
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
sampled_points = masked_gather(points, idx)
return sampled_points, idx
def sample_farthest_points_naive(
points: torch.Tensor,
lengths: Optional[torch.Tensor] = None,
K: Union[int, List, torch.Tensor] = 50,
random_start_point: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Same Args/Returns as sample_farthest_points
"""
N, P, D = points.shape
device = points.device
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)