subsample pointclouds

Summary: New function to randomly subsample Pointclouds to a maximum size.

Reviewed By: nikhilaravi

Differential Revision: D30936533

fbshipit-source-id: 789eb5004b6a233034ec1c500f20f2d507a303ff
This commit is contained in:
Jeremy Reizenstein
2021-10-02 13:37:09 -07:00
committed by Facebook GitHub Bot
parent ee2b2feb98
commit 4281df19ce
3 changed files with 92 additions and 21 deletions

View File

@@ -4,6 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import zip_longest
from typing import Sequence, Union
import numpy as np
import torch
from ..common.types import Device, make_device
@@ -841,6 +845,54 @@ class Pointclouds:
new_clouds = self.clone()
return new_clouds.offset_(offsets_packed)
def subsample(self, max_points: Union[int, Sequence[int]]) -> "Pointclouds":
"""
Subsample each cloud so that it has at most max_points points.
Args:
max_points: maximum number of points in each cloud.
Returns:
new Pointclouds object, or self if nothing to be done.
"""
if isinstance(max_points, int):
max_points = [max_points] * len(self)
elif len(max_points) != len(self):
raise ValueError("wrong number of max_points supplied")
if all(
int(n_points) <= int(max_)
for n_points, max_ in zip(self.num_points_per_cloud(), max_points)
):
return self
points_list = []
features_list = []
normals_list = []
for max_, n_points, points, features, normals in zip_longest(
map(int, max_points),
map(int, self.num_points_per_cloud()),
self.points_list(),
self.features_list() or (),
self.normals_list() or (),
):
if n_points > max_:
keep_np = np.random.choice(n_points, max_, replace=False)
keep = torch.tensor(keep_np).to(points.device)
points = points[keep]
if features is not None:
features = features[keep]
if normals is not None:
normals = normals[keep]
points_list.append(points)
features_list.append(features)
normals_list.append(normals)
return Pointclouds(
points=points_list,
normals=self.normals_list() and normals_list,
features=self.features_list() and features_list,
)
def scale_(self, scale):
"""
Multiply the coordinates of this object by a scalar value.