From 4281df19cefb640067a49b961587342d9e4d85ba Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Sat, 2 Oct 2021 13:37:09 -0700 Subject: [PATCH] subsample pointclouds Summary: New function to randomly subsample Pointclouds to a maximum size. Reviewed By: nikhilaravi Differential Revision: D30936533 fbshipit-source-id: 789eb5004b6a233034ec1c500f20f2d507a303ff --- pytorch3d/structures/pointclouds.py | 52 +++++++++++++++++++++++++++++ pytorch3d/vis/plotly_vis.py | 22 +----------- tests/test_pointclouds.py | 39 ++++++++++++++++++++++ 3 files changed, 92 insertions(+), 21 deletions(-) diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index e36a25f2..7c9cd4cc 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -4,6 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from itertools import zip_longest +from typing import Sequence, Union + +import numpy as np import torch from ..common.types import Device, make_device @@ -841,6 +845,54 @@ class Pointclouds: new_clouds = self.clone() return new_clouds.offset_(offsets_packed) + def subsample(self, max_points: Union[int, Sequence[int]]) -> "Pointclouds": + """ + Subsample each cloud so that it has at most max_points points. + + Args: + max_points: maximum number of points in each cloud. + + Returns: + new Pointclouds object, or self if nothing to be done. + """ + if isinstance(max_points, int): + max_points = [max_points] * len(self) + elif len(max_points) != len(self): + raise ValueError("wrong number of max_points supplied") + if all( + int(n_points) <= int(max_) + for n_points, max_ in zip(self.num_points_per_cloud(), max_points) + ): + return self + + points_list = [] + features_list = [] + normals_list = [] + for max_, n_points, points, features, normals in zip_longest( + map(int, max_points), + map(int, self.num_points_per_cloud()), + self.points_list(), + self.features_list() or (), + self.normals_list() or (), + ): + if n_points > max_: + keep_np = np.random.choice(n_points, max_, replace=False) + keep = torch.tensor(keep_np).to(points.device) + points = points[keep] + if features is not None: + features = features[keep] + if normals is not None: + normals = normals[keep] + points_list.append(points) + features_list.append(features) + normals_list.append(normals) + + return Pointclouds( + points=points_list, + normals=self.normals_list() and normals_list, + features=self.features_list() and features_list, + ) + def scale_(self, scale): """ Multiply the coordinates of this object by a scalar value. diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index 20468767..4a1dcf86 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -7,7 +7,6 @@ import warnings from typing import Dict, List, NamedTuple, Optional, Tuple, Union -import numpy as np import plotly.graph_objects as go import torch from plotly.subplots import make_subplots @@ -644,31 +643,12 @@ def _add_pointcloud_trace( max_points_per_pointcloud: the number of points to render, which are randomly sampled. marker_size: the size of the rendered points """ - pointclouds = pointclouds.detach().cpu() + pointclouds = pointclouds.detach().cpu().subsample(max_points_per_pointcloud) verts = pointclouds.points_packed() features = pointclouds.features_packed() - indices = None - if pointclouds.num_points_per_cloud().max() > max_points_per_pointcloud: - start_index = 0 - index_list = [] - for num_points in pointclouds.num_points_per_cloud(): - if num_points > max_points_per_pointcloud: - indices_cloud = np.random.choice( - num_points, max_points_per_pointcloud, replace=False - ) - index_list.append(start_index + indices_cloud) - else: - index_list.append(start_index + np.arange(num_points)) - start_index += num_points - indices = np.concatenate(index_list) - verts = verts[indices] - color = None if features is not None: - if indices is not None: - # Only select features if we selected vertices above - features = features[indices] if features.shape[1] == 4: # rgba template = "rgb(%d, %d, %d, %f)" rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int() diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index aaeb3d05..f83de817 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -1057,6 +1057,45 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): clouds.normals_packed(), torch.cat(normals_est_list, dim=0) ) + def test_subsample(self): + lengths = [4, 5, 13, 3] + points = [torch.rand(length, 3) for length in lengths] + features = [torch.rand(length, 5) for length in lengths] + normals = [torch.rand(length, 3) for length in lengths] + + pcl1 = Pointclouds(points=points).cuda() + self.assertIs(pcl1, pcl1.subsample(13)) + self.assertIs(pcl1, pcl1.subsample([6, 13, 13, 13])) + + lengths_max_4 = torch.tensor([4, 4, 4, 3]).cuda() + for with_normals, with_features in itertools.product([True, False], repeat=2): + with self.subTest(f"{with_normals} {with_features}"): + pcl = Pointclouds( + points=points, + normals=normals if with_normals else None, + features=features if with_features else None, + ) + pcl_copy = pcl.subsample(max_points=4) + for length, points_ in zip(lengths_max_4, pcl_copy.points_list()): + self.assertEqual(points_.shape, (length, 3)) + if with_normals: + for length, normals_ in zip(lengths_max_4, pcl_copy.normals_list()): + self.assertEqual(normals_.shape, (length, 3)) + else: + self.assertIsNone(pcl_copy.normals_list()) + if with_features: + for length, features_ in zip( + lengths_max_4, pcl_copy.features_list() + ): + self.assertEqual(features_.shape, (length, 5)) + else: + self.assertIsNone(pcl_copy.features_list()) + + pcl2 = Pointclouds(points=points) + pcl_copy2 = pcl2.subsample(lengths_max_4) + for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()): + self.assertEqual(points_.shape, (length, 3)) + @staticmethod def compute_packed_with_init( num_clouds: int = 10, max_p: int = 100, features: int = 300