mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
subsample pointclouds
Summary: New function to randomly subsample Pointclouds to a maximum size. Reviewed By: nikhilaravi Differential Revision: D30936533 fbshipit-source-id: 789eb5004b6a233034ec1c500f20f2d507a303ff
This commit is contained in:
parent
ee2b2feb98
commit
4281df19ce
@ -4,6 +4,10 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from itertools import zip_longest
|
||||
from typing import Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..common.types import Device, make_device
|
||||
@ -841,6 +845,54 @@ class Pointclouds:
|
||||
new_clouds = self.clone()
|
||||
return new_clouds.offset_(offsets_packed)
|
||||
|
||||
def subsample(self, max_points: Union[int, Sequence[int]]) -> "Pointclouds":
|
||||
"""
|
||||
Subsample each cloud so that it has at most max_points points.
|
||||
|
||||
Args:
|
||||
max_points: maximum number of points in each cloud.
|
||||
|
||||
Returns:
|
||||
new Pointclouds object, or self if nothing to be done.
|
||||
"""
|
||||
if isinstance(max_points, int):
|
||||
max_points = [max_points] * len(self)
|
||||
elif len(max_points) != len(self):
|
||||
raise ValueError("wrong number of max_points supplied")
|
||||
if all(
|
||||
int(n_points) <= int(max_)
|
||||
for n_points, max_ in zip(self.num_points_per_cloud(), max_points)
|
||||
):
|
||||
return self
|
||||
|
||||
points_list = []
|
||||
features_list = []
|
||||
normals_list = []
|
||||
for max_, n_points, points, features, normals in zip_longest(
|
||||
map(int, max_points),
|
||||
map(int, self.num_points_per_cloud()),
|
||||
self.points_list(),
|
||||
self.features_list() or (),
|
||||
self.normals_list() or (),
|
||||
):
|
||||
if n_points > max_:
|
||||
keep_np = np.random.choice(n_points, max_, replace=False)
|
||||
keep = torch.tensor(keep_np).to(points.device)
|
||||
points = points[keep]
|
||||
if features is not None:
|
||||
features = features[keep]
|
||||
if normals is not None:
|
||||
normals = normals[keep]
|
||||
points_list.append(points)
|
||||
features_list.append(features)
|
||||
normals_list.append(normals)
|
||||
|
||||
return Pointclouds(
|
||||
points=points_list,
|
||||
normals=self.normals_list() and normals_list,
|
||||
features=self.features_list() and features_list,
|
||||
)
|
||||
|
||||
def scale_(self, scale):
|
||||
"""
|
||||
Multiply the coordinates of this object by a scalar value.
|
||||
|
@ -7,7 +7,6 @@
|
||||
import warnings
|
||||
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import torch
|
||||
from plotly.subplots import make_subplots
|
||||
@ -644,31 +643,12 @@ def _add_pointcloud_trace(
|
||||
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
|
||||
marker_size: the size of the rendered points
|
||||
"""
|
||||
pointclouds = pointclouds.detach().cpu()
|
||||
pointclouds = pointclouds.detach().cpu().subsample(max_points_per_pointcloud)
|
||||
verts = pointclouds.points_packed()
|
||||
features = pointclouds.features_packed()
|
||||
|
||||
indices = None
|
||||
if pointclouds.num_points_per_cloud().max() > max_points_per_pointcloud:
|
||||
start_index = 0
|
||||
index_list = []
|
||||
for num_points in pointclouds.num_points_per_cloud():
|
||||
if num_points > max_points_per_pointcloud:
|
||||
indices_cloud = np.random.choice(
|
||||
num_points, max_points_per_pointcloud, replace=False
|
||||
)
|
||||
index_list.append(start_index + indices_cloud)
|
||||
else:
|
||||
index_list.append(start_index + np.arange(num_points))
|
||||
start_index += num_points
|
||||
indices = np.concatenate(index_list)
|
||||
verts = verts[indices]
|
||||
|
||||
color = None
|
||||
if features is not None:
|
||||
if indices is not None:
|
||||
# Only select features if we selected vertices above
|
||||
features = features[indices]
|
||||
if features.shape[1] == 4: # rgba
|
||||
template = "rgb(%d, %d, %d, %f)"
|
||||
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()
|
||||
|
@ -1057,6 +1057,45 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
clouds.normals_packed(), torch.cat(normals_est_list, dim=0)
|
||||
)
|
||||
|
||||
def test_subsample(self):
|
||||
lengths = [4, 5, 13, 3]
|
||||
points = [torch.rand(length, 3) for length in lengths]
|
||||
features = [torch.rand(length, 5) for length in lengths]
|
||||
normals = [torch.rand(length, 3) for length in lengths]
|
||||
|
||||
pcl1 = Pointclouds(points=points).cuda()
|
||||
self.assertIs(pcl1, pcl1.subsample(13))
|
||||
self.assertIs(pcl1, pcl1.subsample([6, 13, 13, 13]))
|
||||
|
||||
lengths_max_4 = torch.tensor([4, 4, 4, 3]).cuda()
|
||||
for with_normals, with_features in itertools.product([True, False], repeat=2):
|
||||
with self.subTest(f"{with_normals} {with_features}"):
|
||||
pcl = Pointclouds(
|
||||
points=points,
|
||||
normals=normals if with_normals else None,
|
||||
features=features if with_features else None,
|
||||
)
|
||||
pcl_copy = pcl.subsample(max_points=4)
|
||||
for length, points_ in zip(lengths_max_4, pcl_copy.points_list()):
|
||||
self.assertEqual(points_.shape, (length, 3))
|
||||
if with_normals:
|
||||
for length, normals_ in zip(lengths_max_4, pcl_copy.normals_list()):
|
||||
self.assertEqual(normals_.shape, (length, 3))
|
||||
else:
|
||||
self.assertIsNone(pcl_copy.normals_list())
|
||||
if with_features:
|
||||
for length, features_ in zip(
|
||||
lengths_max_4, pcl_copy.features_list()
|
||||
):
|
||||
self.assertEqual(features_.shape, (length, 5))
|
||||
else:
|
||||
self.assertIsNone(pcl_copy.features_list())
|
||||
|
||||
pcl2 = Pointclouds(points=points)
|
||||
pcl_copy2 = pcl2.subsample(lengths_max_4)
|
||||
for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
|
||||
self.assertEqual(points_.shape, (length, 3))
|
||||
|
||||
@staticmethod
|
||||
def compute_packed_with_init(
|
||||
num_clouds: int = 10, max_p: int = 100, features: int = 300
|
||||
|
Loading…
x
Reference in New Issue
Block a user