diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index 7765b006..22e510e5 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -61,14 +61,14 @@ def plot_scene( **kwargs, ): """ - Main function to visualize Meshes and Pointclouds. + Main function to visualize Meshes, Cameras and Pointclouds. Plots input Pointclouds, Meshes, and Cameras data into named subplots, with named traces based on the dictionary keys. Cameras are rendered at the camera center location using a wireframe. Args: plots: A dict containing subplot and trace names, - as well as the Meshes and Pointclouds objects to be rendered. + as well as the Meshes, Cameras and Pointclouds objects to be rendered. See below for examples of the format. viewpoint_cameras: an instance of a Cameras object providing a location to view the plotly plot from. If the batch size is equal @@ -574,11 +574,21 @@ def _add_pointcloud_trace( pointclouds = pointclouds.detach().cpu() verts = pointclouds.points_packed() features = pointclouds.features_packed() - total_points_count = max_points_per_pointcloud * len(pointclouds) indices = None - if verts.shape[0] > total_points_count: - indices = np.random.choice(verts.shape[0], total_points_count, replace=False) + if pointclouds.num_points_per_cloud().max() > max_points_per_pointcloud: + start_index = 0 + index_list = [] + for num_points in pointclouds.num_points_per_cloud(): + if num_points > max_points_per_pointcloud: + indices_cloud = np.random.choice( + num_points, max_points_per_pointcloud, replace=False + ) + index_list.append(start_index + indices_cloud) + else: + index_list.append(start_index + np.arange(num_points)) + start_index += num_points + indices = np.concatenate(index_list) verts = verts[indices] color = None