Farthest point sampling C++

Summary: C++ implementation of iterative farthest point sampling.
Reviewed By: jcjohnson
Differential Revision: D30349887
fbshipit-source-id: d25990f857752633859fe00283e182858a870269
committed by Facebook GitHub Bot
parent 3b7d78c7a7
commit d9f7611c4b
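
The summary above refers to iterative farthest point sampling: starting from an initial point, each step selects the point whose minimum distance to the already-selected set is largest. Below is a minimal single-cloud sketch of that selection loop, for illustration only; it is not the C++ code added by this commit, and the function and variable names are invented here.

import torch

def fps_single_cloud(points: torch.Tensor, K: int) -> torch.Tensor:
    # Illustrative iterative farthest point sampling for one (P, D) point cloud.
    # Returns the indices of the K selected points (K <= P).
    P = points.shape[0]
    selected = torch.zeros(K, dtype=torch.int64)
    # Distance from every point to its nearest already-selected point.
    min_dists = torch.full((P,), float("inf"))
    selected[0] = 0  # the real API can instead use a random start point
    for i in range(1, K):
        last = points[selected[i - 1]]                    # (D,)
        dist_to_last = ((points - last) ** 2).sum(dim=1)  # (P,)
        min_dists = torch.minimum(min_dists, dist_to_last)
        selected[i] = torch.argmax(min_dists)             # farthest from the selected set
    return selected

The diff below shows the Python wrapper that dispatches to the new C++ implementation.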
@@ -8,11 +8,12 @@ from random import randint
 from typing import Optional, Tuple, Union, List

 import torch
+from pytorch3d import _C

 from .utils import masked_gather


-def sample_farthest_points_naive(
+def sample_farthest_points(
     points: torch.Tensor,
     lengths: Optional[torch.Tensor] = None,
     K: Union[int, List, torch.Tensor] = 50,
@@ -34,7 +35,7 @@ def sample_farthest_points_naive(
         points: (N, P, D) array containing the batch of pointclouds
         lengths: (N,) number of points in each pointcloud (to support heterogeneous
             batches of pointclouds)
-        K: samples you want in each sampled point cloud (this is typically << P). If
+        K: samples required in each sampled point cloud (this is typically << P). If
             K is an int then the same number of samples are selected for each
             pointcloud in the batch. If K is a tensor is should be length (N,)
             giving the number of samples to select for each element in the batch
@@ -52,6 +53,50 @@ def sample_farthest_points_naive(
     N, P, D = points.shape
     device = points.device

+    # Validate inputs
+    if lengths is None:
+        lengths = torch.full((N,), P, dtype=torch.int64, device=device)
+
+    if lengths.shape[0] != N:
+        raise ValueError("points and lengths must have same batch dimension.")
+
+    # TODO: support providing K as a ratio of the total number of points instead of as an int
+    if isinstance(K, int):
+        K = torch.full((N,), K, dtype=torch.int64, device=device)
+    elif isinstance(K, list):
+        K = torch.tensor(K, dtype=torch.int64, device=device)
+
+    if K.shape[0] != N:
+        raise ValueError("K and points must have the same batch dimension")
+
+    # Check dtypes are correct and convert if necessary
+    if not (points.dtype == torch.float32):
+        points = points.to(torch.float32)
+    if not (lengths.dtype == torch.int64):
+        lengths = lengths.to(torch.int64)
+    if not (K.dtype == torch.int64):
+        K = K.to(torch.int64)
+
+    with torch.no_grad():
+        # pyre-fixme[16]: `pytorch3d._C` has no attribute `sample_farthest_points`.
+        idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
+    sampled_points = masked_gather(points, idx)
+
+    return sampled_points, idx
+
+
+def sample_farthest_points_naive(
+    points: torch.Tensor,
+    lengths: Optional[torch.Tensor] = None,
+    K: Union[int, List, torch.Tensor] = 50,
+    random_start_point: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Same Args/Returns as sample_farthest_points
+    """
+    N, P, D = points.shape
+    device = points.device
+
     # Validate inputs
     if lengths is None:
         lengths = torch.full((N,), P, dtype=torch.int64, device=device)
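
For reference, a usage sketch of the new Python entry point shown above. The pytorch3d.ops import path and the shapes noted in the comments are assumptions based on the function's docstring and the public PyTorch3D API, not part of this diff.

import torch
from pytorch3d.ops import sample_farthest_points  # assumed import path

# Two pointclouds padded to the same P; the second has only 60 valid points.
N, P, D = 2, 100, 3
points = torch.rand(N, P, D)
lengths = torch.tensor([100, 60])

# K as an int: sample 10 points from every cloud in the batch.
sampled, idx = sample_farthest_points(points, lengths=lengths, K=10)
# sampled is expected to be (2, 10, 3) and idx (2, 10), indexing into the padded point dimension.

# K as a per-cloud tensor, with a random start point per cloud.
sampled, idx = sample_farthest_points(
    points, lengths=lengths, K=torch.tensor([10, 5]), random_start_point=True
)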