subsample pointclouds

Summary: New function to randomly subsample Pointclouds to a maximum size.

Reviewed By: nikhilaravi

Differential Revision: D30936533

fbshipit-source-id: 789eb5004b6a233034ec1c500f20f2d507a303ff
This commit is contained in:
Jeremy Reizenstein
2021-10-02 13:37:09 -07:00
committed by Facebook GitHub Bot
parent ee2b2feb98
commit 4281df19ce
3 changed files with 92 additions and 21 deletions

View File

@@ -4,6 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import zip_longest
from typing import Sequence, Union
import numpy as np
import torch
from ..common.types import Device, make_device
@@ -841,6 +845,54 @@ class Pointclouds:
new_clouds = self.clone()
return new_clouds.offset_(offsets_packed)
def subsample(self, max_points: Union[int, Sequence[int]]) -> "Pointclouds":
"""
Subsample each cloud so that it has at most max_points points.
Args:
max_points: maximum number of points in each cloud.
Returns:
new Pointclouds object, or self if nothing to be done.
"""
if isinstance(max_points, int):
max_points = [max_points] * len(self)
elif len(max_points) != len(self):
raise ValueError("wrong number of max_points supplied")
if all(
int(n_points) <= int(max_)
for n_points, max_ in zip(self.num_points_per_cloud(), max_points)
):
return self
points_list = []
features_list = []
normals_list = []
for max_, n_points, points, features, normals in zip_longest(
map(int, max_points),
map(int, self.num_points_per_cloud()),
self.points_list(),
self.features_list() or (),
self.normals_list() or (),
):
if n_points > max_:
keep_np = np.random.choice(n_points, max_, replace=False)
keep = torch.tensor(keep_np).to(points.device)
points = points[keep]
if features is not None:
features = features[keep]
if normals is not None:
normals = normals[keep]
points_list.append(points)
features_list.append(features)
normals_list.append(normals)
return Pointclouds(
points=points_list,
normals=self.normals_list() and normals_list,
features=self.features_list() and features_list,
)
def scale_(self, scale):
"""
Multiply the coordinates of this object by a scalar value.

View File

@@ -7,7 +7,6 @@
import warnings
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
@@ -644,31 +643,12 @@ def _add_pointcloud_trace(
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
marker_size: the size of the rendered points
"""
pointclouds = pointclouds.detach().cpu()
pointclouds = pointclouds.detach().cpu().subsample(max_points_per_pointcloud)
verts = pointclouds.points_packed()
features = pointclouds.features_packed()
indices = None
if pointclouds.num_points_per_cloud().max() > max_points_per_pointcloud:
start_index = 0
index_list = []
for num_points in pointclouds.num_points_per_cloud():
if num_points > max_points_per_pointcloud:
indices_cloud = np.random.choice(
num_points, max_points_per_pointcloud, replace=False
)
index_list.append(start_index + indices_cloud)
else:
index_list.append(start_index + np.arange(num_points))
start_index += num_points
indices = np.concatenate(index_list)
verts = verts[indices]
color = None
if features is not None:
if indices is not None:
# Only select features if we selected vertices above
features = features[indices]
if features.shape[1] == 4: # rgba
template = "rgb(%d, %d, %d, %f)"
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()