diff --git a/pytorch3d/vis/__init__.py b/pytorch3d/vis/__init__.py index 9c49f36d..2ce6f7bb 100644 --- a/pytorch3d/vis/__init__.py +++ b/pytorch3d/vis/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds +from .plotly_vis import AxisArgs, Lighting, plot_scene __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index 8c252d82..fbe8e1b2 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -1,14 +1,14 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import warnings -from typing import NamedTuple, Optional, Tuple +from typing import Dict, List, NamedTuple, Union import numpy as np import plotly.graph_objects as go import torch from plotly.subplots import make_subplots from pytorch3d.renderer import TexturesVertex -from pytorch3d.structures import Meshes, Pointclouds +from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene class AxisArgs(NamedTuple): @@ -31,216 +31,333 @@ class Lighting(NamedTuple): vertexnormalsepsilon: float = 1e-12 -def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **kwargs): - """ - Takes a Meshes object and generates a plotly figure. If there is more than - one mesh in the batch and in_subplots=True, each mesh will be - visualized in an individual subplot with ncols number of subplots in the same row. - Otherwise, each mesh in the batch will be visualized as an individual trace in the - same plot. If the Meshes object has vertex colors defined as its texture, the vertex - colors will be used for generating the plotly figure. Otherwise plotly's default - colors will be used. - - Args: - meshes: Meshes object to be visualized in a plotly figure. - in_subplots: if each mesh in the batch should be visualized in an individual subplot - ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise - ncols will be ignored. - **kwargs: Accepts lighting (a Lighting object) and lightposition for Mesh3D and any of - the args xaxis, yaxis and zaxis accept for scene. Accepts axis_args which is an - AxisArgs object that is applied to the 3 axes. Also accepts subplot_titles, which - should be a list of string titles matching the number of subplots. - Example settings for axis_args and lighting are given above. - - Returns: - Plotly figure of the mesh. If there is more than one mesh in the batch, - the plotly figure will contain a series of vertically stacked subplots. - """ - meshes = meshes.detach().cpu() - subplot_titles = kwargs.get("subplot_titles", None) - fig = _gen_fig_with_subplots(len(meshes), in_subplots, ncols, subplot_titles) - for i in range(len(meshes)): - verts = meshes[i].verts_packed() - faces = meshes[i].faces_packed() - # If mesh has vertex colors defined as texture, use vertex colors - # for figure, otherwise use plotly's default colors. - verts_rgb = None - if isinstance(meshes[i].textures, TexturesVertex): - verts_rgb = meshes[i].textures.verts_features_packed() - verts_rgb.clamp_(min=0.0, max=1.0) - verts_rgb = torch.tensor(255.0) * verts_rgb - - # Reposition the unused vertices to be "inside" the object - # (i.e. they won't be visible in the plot). - verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool) - verts_used[torch.unique(faces)] = True - verts_center = verts[verts_used].mean(0) - verts[~verts_used] = verts_center - - trace_row = i // ncols + 1 if in_subplots else 1 - trace_col = i % ncols + 1 if in_subplots else 1 - fig.add_trace( - go.Mesh3d( # pyre-ignore[16] - x=verts[:, 0], - y=verts[:, 1], - z=verts[:, 2], - vertexcolor=verts_rgb, - i=faces[:, 0], - j=faces[:, 1], - k=faces[:, 2], - lighting=kwargs.get("lighting", Lighting())._asdict(), - lightposition=kwargs.get("lightposition", {}), - ), - row=trace_row, - col=trace_col, - ) - # Ensure update for every subplot. - plot_scene = "scene" + str(i + 1) if in_subplots else "scene" - current_layout = fig["layout"][plot_scene] - - axis_args = kwargs.get("axis_args", AxisArgs()) - - xaxis, yaxis, zaxis = _gen_updated_axis_bounds( - verts, verts_center, current_layout, axis_args - ) - # Update the axis bounds with the axis settings passed in as kwargs. - xaxis.update(**kwargs.get("xaxis", {})) - yaxis.update(**kwargs.get("yaxis", {})) - zaxis.update(**kwargs.get("zaxis", {})) - - current_layout.update( - {"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"} - ) - return fig - - -def plot_pointclouds( - pointclouds: Pointclouds, +def plot_scene( + plots: Dict[str, Dict[str, Union[Pointclouds, Meshes]]], *, - in_subplots: bool = False, ncols: int = 1, - max_points: int = 20000, + pointcloud_max_points: int = 20000, + pointcloud_marker_size: int = 1, **kwargs, ): """ - Takes a Pointclouds object and generates a plotly figure. If there is more than - one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be - visualized in an individual subplot with ncols number of subplots in the same row. - Otherwise, each pointcloud in the batch will be visualized as an individual trace in the - same plot. If the Pointclouds object has features that are size (3) or (4) then those - rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors - will be used. Assumes that all rgb/rgba feature values are in the range [0,1]. + Main function to visualize Meshes and Pointclouds. + Plots input Pointclouds and Meshes data into named subplots, + with named traces based on the dictionary keys. Args: - pointclouds: Pointclouds object which can contain a batch of pointclouds. - in_subplots: if each pointcloud should be visualized in an individual subplot. - ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise - ncols will be ignored. - max_points: maximum number of points to plot. If the cloud has more, they are - randomly subsampled. - **kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D - and any of the args xaxis, yaxis and zaxis which scene accepts. - Accepts axis_args which is an AxisArgs object that is applied to the 3 axes. - Also accepts subplot_titles, whichshould be a list of string titles - matching the number of subplots. Example settings for axis_args and lighting are - given at the top of this file. + plots: A dict containing subplot and trace names, + as well as the Meshes and Pointclouds objects to be rendered. + See below for examples of the format. + ncols: the number of subplots per row + pointcloud_max_points: the maximum number of points to plot from + a pointcloud. If more are present, a random sample of size + pointcloud_max_points is used. + pointcloud_marker_size: the size of the points rendered by plotly + when plotting a pointcloud. + **kwargs: Accepts lighting (a Lighting object) and any of the args xaxis, + yaxis and zaxis which Plotly's scene accepts. Accepts axis_args, + which is an AxisArgs object that is applied to all 3 axes. + Example settings for axis_args and lighting are given at the + top of this file. - Returns: - Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch, - the plotly figure will contain a plot with one trace per pointcloud, or with each - pointcloud in a separate subplot if in_subplots is True. + Example: + + ..code-block::python + + mesh = ... + point_cloud = ... + fig = plot_scene({ + "subplot_title": { + "mesh_trace_title": mesh, + "pointcloud_trace_title": point_cloud + } + }) + fig.show() + + The above example will render one subplot which has both a mesh and pointcloud. + + If the Meshes or Pointclouds objects are batched, then every object in that batch + will be plotted in a single trace. + + ..code-block::python + mesh = ... # batch size 2 + point_cloud = ... # batch size 2 + fig = plot_scene({ + "subplot_title": { + "mesh_trace_title": mesh, + "pointcloud_trace_title": point_cloud + } + }) + fig.show() + + The above example renders one subplot with 2 traces, each of which renders + both objects from their respective batched data. + + Multiple subplots follow the same pattern: + ..code-block::python + mesh = ... # batch size 2 + point_cloud = ... # batch size 2 + fig = plot_scene({ + "subplot1_title": { + "mesh_trace_title": mesh[0], + "pointcloud_trace_title": point_cloud[0] + }, + "subplot2_title": { + "mesh_trace_title": mesh[1], + "pointcloud_trace_title": point_cloud[1] + } + }, + ncols=2) # specify the number of subplots per row + fig.show() + + The above example will render two subplots, each containing a mesh + and a pointcloud. The ncols argument will render two subplots in one row + instead of having them vertically stacked because the default is one subplot + per row. + + For an example of using kwargs, see below: + ..code-block::python + mesh = ... + point_cloud = ... + fig = plot_scene({ + "subplot_title": { + "mesh_trace_title": mesh, + "pointcloud_trace_title": point_cloud + } + }, + axis_args=AxisArgs(backgroundcolor="rgb(200,230,200)")) # kwarg axis_args + fig.show() + + The above example will render each axis with the input background color. + + See the tutorials in pytorch3d/docs/tutorials for more examples + (namely rendered_color_points.ipynb and rendered_textured_meshes.ipynb). """ - pointclouds = pointclouds.detach().cpu() - subplot_titles = kwargs.get("subplot_titles", None) - fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles) - for i in range(len(pointclouds)): - verts = pointclouds[i].points_packed() - features = pointclouds[i].features_packed() + subplots = list(plots.keys()) + fig = _gen_fig_with_subplots(len(subplots), ncols, subplots) + lighting = kwargs.get("lighting", Lighting())._asdict() + axis_args_dict = kwargs.get("axis_args", AxisArgs())._asdict() - indices = None - if max_points is not None and verts.shape[0] > max_points: - indices = np.random.choice(verts.shape[0], max_points, replace=False) - verts = verts[indices] + # Set axis arguments to defaults defined at the top of this file + x_settings = {**axis_args_dict} + y_settings = {**axis_args_dict} + z_settings = {**axis_args_dict} - color = None - if features is not None: - features = features[indices] - if features.shape[1] == 4: # rgba - template = "rgb(%d, %d, %d, %f)" - rgb = (features[:, :3] * 255).int() - color = [ - template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3]) - ] + # Update the axes with any axis settings passed in as kwargs. + x_settings.update(**kwargs.get("xaxis", {})) + y_settings.update(**kwargs.get("yaxis", {})) + z_settings.update(**kwargs.get("zaxis", {})) - if features.shape[1] == 3: - template = "rgb(%d, %d, %d)" - rgb = (features * 255).int() - color = [template % (r, g, b) for r, g, b in rgb] + camera = { + "up": { + "x": 0, + "y": 1, + "z": 0, + } # set the up vector to match PyTorch3D world coordinates conventions + } - trace_row = i // ncols + 1 if in_subplots else 1 - trace_col = i % ncols + 1 if in_subplots else 1 - fig.add_trace( - go.Scatter3d( # pyre-ignore[16] - x=verts[:, 0], - y=verts[:, 1], - z=verts[:, 2], - marker={"color": color, "size": 1}, - mode="markers", - ), - row=trace_row, - col=trace_col, - ) + for subplot_idx in range(len(subplots)): + subplot_name = subplots[subplot_idx] + traces = plots[subplot_name] + for trace_name, struct in traces.items(): + if isinstance(struct, Meshes): + _add_mesh_trace(fig, struct, trace_name, subplot_idx, ncols, lighting) + elif isinstance(struct, Pointclouds): + _add_pointcloud_trace( + fig, + struct, + trace_name, + subplot_idx, + ncols, + pointcloud_max_points, + pointcloud_marker_size, + ) + else: + raise ValueError( + "struct {} is not a Meshes or Pointclouds object".format(struct) + ) # Ensure update for every subplot. - plot_scene = "scene" + str(i + 1) if in_subplots else "scene" + plot_scene = "scene" + str(subplot_idx + 1) current_layout = fig["layout"][plot_scene] + xaxis = current_layout["xaxis"] + yaxis = current_layout["yaxis"] + zaxis = current_layout["zaxis"] - verts_center = verts.mean(0) - - axis_args = kwargs.get("axis_args", AxisArgs()) - - xaxis, yaxis, zaxis = _gen_updated_axis_bounds( - verts, verts_center, current_layout, axis_args - ) - xaxis.update(**kwargs.get("xaxis", {})) - yaxis.update(**kwargs.get("yaxis", {})) - zaxis.update(**kwargs.get("zaxis", {})) + # Update the axes with our above default and provided settings. + xaxis.update(**x_settings) + yaxis.update(**y_settings) + zaxis.update(**z_settings) current_layout.update( - {"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"} + { + "xaxis": xaxis, + "yaxis": yaxis, + "zaxis": zaxis, + "aspectmode": "cube", + "camera": camera, + } ) return fig -def _gen_fig_with_subplots( - batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list] +def _add_mesh_trace( + fig: go.Figure, # pyre-ignore[11] + meshes: Meshes, + trace_name: str, + subplot_idx: int, + ncols: int, + lighting: Lighting, ): + """ + Adds a trace rendering a Meshes object to the passed in figure, with + a given name and in a specific subplot. + + Args: + fig: plotly figure to add the trace within. + meshes: Meshes object to render. It can be batched. + trace_name: name to label the trace with. + subplot_idx: identifies the subplot, with 0 being the top left. + ncols: the number of subplots per row. + lighting: a Lighting object that specifies the Mesh3D lighting. + """ + + mesh = join_meshes_as_scene(meshes) + mesh = mesh.detach().cpu() + verts = mesh.verts_packed() + faces = mesh.faces_packed() + # If mesh has vertex colors defined as texture, use vertex colors + # for figure, otherwise use plotly's default colors. + verts_rgb = None + if isinstance(mesh.textures, TexturesVertex): + verts_rgb = mesh.textures.verts_features_packed() + verts_rgb.clamp_(min=0.0, max=1.0) + verts_rgb = torch.tensor(255.0) * verts_rgb + + # Reposition the unused vertices to be "inside" the object + # (i.e. they won't be visible in the plot). + verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool) + verts_used[torch.unique(faces)] = True + verts_center = verts[verts_used].mean(0) + verts[~verts_used] = verts_center + + row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1 + fig.add_trace( + go.Mesh3d( # pyre-ignore[16] + x=verts[:, 0], + y=verts[:, 1], + z=verts[:, 2], + vertexcolor=verts_rgb, + i=faces[:, 0], + j=faces[:, 1], + k=faces[:, 2], + lighting=lighting, + name=trace_name, + ), + row=row, + col=col, + ) + + # Access the current subplot's scene configuration + plot_scene = "scene" + str(subplot_idx + 1) + current_layout = fig["layout"][plot_scene] + + # update the bounds of the axes for the current trace + max_expand = (verts.max(0)[0] - verts.min(0)[0]).max() + _update_axes_bounds(verts_center, max_expand, current_layout) + + +def _add_pointcloud_trace( + fig: go.Figure, + pointclouds: Pointclouds, + trace_name: str, + subplot_idx: int, + ncols: int, + max_points_per_pointcloud: int, + marker_size: int, +): + """ + Adds a trace rendering a Pointclouds object to the passed in figure, with + a given name and in a specific subplot. + + Args: + fig: plotly figure to add the trace within. + pointclouds: Pointclouds object to render. It can be batched. + trace_name: name to label the trace with. + subplot_idx: identifies the subplot, with 0 being the top left. + ncols: the number of sublpots per row. + max_points_per_pointcloud: the number of points to render, which are randomly sampled. + marker_size: the size of the rendered points + """ + pointclouds = pointclouds.detach().cpu() + verts = pointclouds.points_packed() + features = pointclouds.features_packed() + total_points_count = max_points_per_pointcloud * len(pointclouds) + + indices = None + if verts.shape[0] > total_points_count: + indices = np.random.choice(verts.shape[0], total_points_count, replace=False) + verts = verts[indices] + + color = None + if features is not None: + features = features[indices] + if features.shape[1] == 4: # rgba + template = "rgb(%d, %d, %d, %f)" + rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int() + color = [template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])] + + if features.shape[1] == 3: + template = "rgb(%d, %d, %d)" + rgb = (features.clamp(0.0, 1.0) * 255).int() + color = [template % (r, g, b) for r, g, b in rgb] + + row = subplot_idx // ncols + 1 + col = subplot_idx % ncols + 1 + fig.add_trace( + go.Scatter3d( # pyre-ignore[16] + x=verts[:, 0], + y=verts[:, 1], + z=verts[:, 2], + marker={"color": color, "size": marker_size}, + mode="markers", + name=trace_name, + ), + row=row, + col=col, + ) + + # Access the current subplot's scene configuration + plot_scene = "scene" + str(subplot_idx + 1) + current_layout = fig["layout"][plot_scene] + + # update the bounds of the axes for the current trace + verts_center = verts.mean(0) + max_expand = (verts.max(0)[0] - verts.min(0)[0]).max() + _update_axes_bounds(verts_center, max_expand, current_layout) + + +def _gen_fig_with_subplots(batch_size: int, ncols: int, subplot_titles: List[str]): """ Takes in the number of objects to be plotted and generate a plotly figure - with the appropriate number and orientation of subplots + with the appropriate number and orientation of titled subplots. Args: batch_size: the number of elements in the batch of objects to be visualized. - in_subplots: if each object should be visualized in an individual subplot. - ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise - ncols will be ignored. - subplot_titles: titles for the subplot(s). list of strings of length batch_size - if in_subplots is True, otherwise length 1. + ncols: number of subplots in the same row. + subplot_titles: titles for the subplot(s). list of strings of length batch_size. Returns: - Plotly figure with one plot if in_subplots is false. Otherwise, returns a plotly figure - with ncols subplots per row. + Plotly figure with ncols subplots per row, and batch_size subplots. """ if batch_size % ncols != 0: msg = "ncols is invalid for the given mesh batch size." warnings.warn(msg) - fig_rows = batch_size // ncols if in_subplots else 1 - fig_cols = ncols if in_subplots else 1 + fig_rows = batch_size // ncols + fig_cols = ncols fig_type = [{"type": "scene"}] specs = [fig_type * fig_cols] * fig_rows # subplot_titles must have one title per subplot - if subplot_titles is not None and len(subplot_titles) != fig_cols * fig_rows: - subplot_titles = None fig = make_subplots( rows=fig_rows, cols=fig_cols, @@ -251,27 +368,19 @@ def _gen_fig_with_subplots( return fig -def _gen_updated_axis_bounds( - verts: torch.Tensor, +def _update_axes_bounds( verts_center: torch.Tensor, + max_expand: float, current_layout: go.Scene, # pyre-ignore[11] - axis_args: AxisArgs, -) -> Tuple[dict, dict, dict]: +): """ - Takes in the vertices, center point of the vertices, and the current plotly figure and - outputs axes with bounds that capture all points in the current subplot. + Takes in the vertices' center point and max spread, and the current plotly figure + layout and updates the layout to have bounds that include all traces for that subplot. Args: - verts: tensor of size (N, 3) representing N points with xyz coordinates. - verts_center: tensor of size (3) corresponding to the center point of verts. - current_layout: the current plotly figure layout scene corresponding to verts' trace. - axis_args: an AxisArgs object with default and/or user-set values for plotly's axes. - - Returns: - a 3 item tuple of xaxis, yaxis, and zaxis, which are dictionaries with axis arguments - for plotly including a range key with value the minimum and maximum value for that axis. + verts_center: tensor of size (3) corresponding to a trace's vertices' center point. + max_expand: the maximum spread in any dimension of the trace's vertices. + current_layout: the plotly figure layout scene corresponding to the referenced trace. """ - # Get ranges of vertices. - max_expand = (verts.max(0)[0] - verts.min(0)[0]).max() verts_min = verts_center - max_expand verts_max = verts_center + max_expand bounds = torch.t(torch.stack((verts_min, verts_max))) @@ -292,8 +401,8 @@ def _gen_updated_axis_bounds( if old_zrange is not None: z_range[0] = min(z_range[0], old_zrange[0]) z_range[1] = max(z_range[1], old_zrange[1]) - axis_args_dict = axis_args._asdict() - xaxis = {"range": x_range, **axis_args_dict} - yaxis = {"range": y_range, **axis_args_dict} - zaxis = {"range": z_range, **axis_args_dict} - return xaxis, yaxis, zaxis + + xaxis = {"range": x_range} + yaxis = {"range": y_range} + zaxis = {"range": z_range} + current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})