diff --git a/pytorch3d/visualization/__init__.py b/pytorch3d/visualization/__init__.py index a9b9685c..9c49f36d 100644 --- a/pytorch3d/visualization/__init__.py +++ b/pytorch3d/visualization/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .mesh_plotly import AxisArgs, Lighting, plot_meshes +from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/visualization/mesh_plotly.py b/pytorch3d/visualization/plotly_vis.py similarity index 65% rename from pytorch3d/visualization/mesh_plotly.py rename to pytorch3d/visualization/plotly_vis.py index 14b2f1a4..8c252d82 100644 --- a/pytorch3d/visualization/mesh_plotly.py +++ b/pytorch3d/visualization/plotly_vis.py @@ -3,11 +3,12 @@ import warnings from typing import NamedTuple, Optional, Tuple +import numpy as np import plotly.graph_objects as go import torch from plotly.subplots import make_subplots from pytorch3d.renderer import TexturesVertex -from pytorch3d.structures import Meshes +from pytorch3d.structures import Meshes, Pointclouds class AxisArgs(NamedTuple): @@ -113,6 +114,105 @@ def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, ** return fig +def plot_pointclouds( + pointclouds: Pointclouds, + *, + in_subplots: bool = False, + ncols: int = 1, + max_points: int = 20000, + **kwargs, +): + """ + Takes a Pointclouds object and generates a plotly figure. If there is more than + one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be + visualized in an individual subplot with ncols number of subplots in the same row. + Otherwise, each pointcloud in the batch will be visualized as an individual trace in the + same plot. If the Pointclouds object has features that are size (3) or (4) then those + rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors + will be used. Assumes that all rgb/rgba feature values are in the range [0,1]. + + Args: + pointclouds: Pointclouds object which can contain a batch of pointclouds. + in_subplots: if each pointcloud should be visualized in an individual subplot. + ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise + ncols will be ignored. + max_points: maximum number of points to plot. If the cloud has more, they are + randomly subsampled. + **kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D + and any of the args xaxis, yaxis and zaxis which scene accepts. + Accepts axis_args which is an AxisArgs object that is applied to the 3 axes. + Also accepts subplot_titles, whichshould be a list of string titles + matching the number of subplots. Example settings for axis_args and lighting are + given at the top of this file. + + Returns: + Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch, + the plotly figure will contain a plot with one trace per pointcloud, or with each + pointcloud in a separate subplot if in_subplots is True. + """ + pointclouds = pointclouds.detach().cpu() + subplot_titles = kwargs.get("subplot_titles", None) + fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles) + for i in range(len(pointclouds)): + verts = pointclouds[i].points_packed() + features = pointclouds[i].features_packed() + + indices = None + if max_points is not None and verts.shape[0] > max_points: + indices = np.random.choice(verts.shape[0], max_points, replace=False) + verts = verts[indices] + + color = None + if features is not None: + features = features[indices] + if features.shape[1] == 4: # rgba + template = "rgb(%d, %d, %d, %f)" + rgb = (features[:, :3] * 255).int() + color = [ + template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3]) + ] + + if features.shape[1] == 3: + template = "rgb(%d, %d, %d)" + rgb = (features * 255).int() + color = [template % (r, g, b) for r, g, b in rgb] + + trace_row = i // ncols + 1 if in_subplots else 1 + trace_col = i % ncols + 1 if in_subplots else 1 + fig.add_trace( + go.Scatter3d( # pyre-ignore[16] + x=verts[:, 0], + y=verts[:, 1], + z=verts[:, 2], + marker={"color": color, "size": 1}, + mode="markers", + ), + row=trace_row, + col=trace_col, + ) + + # Ensure update for every subplot. + plot_scene = "scene" + str(i + 1) if in_subplots else "scene" + current_layout = fig["layout"][plot_scene] + + verts_center = verts.mean(0) + + axis_args = kwargs.get("axis_args", AxisArgs()) + + xaxis, yaxis, zaxis = _gen_updated_axis_bounds( + verts, verts_center, current_layout, axis_args + ) + xaxis.update(**kwargs.get("xaxis", {})) + yaxis.update(**kwargs.get("yaxis", {})) + zaxis.update(**kwargs.get("zaxis", {})) + + current_layout.update( + {"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"} + ) + + return fig + + def _gen_fig_with_subplots( batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list] ):