Refactor plot_meshes and plot_pointclouds to one generalizable API, plot_scene

Summary: Defines a function plot_scene that takes in a dictionary defining subplot and trace layouts for Mesh/Pointcloud objects and plots them. Also supports other plotly axis arguments and mesh lighting. Plot_batch_individually is a wrapper function that takes in one or multiple batched Meshes/Pointclouds and uses plot_scene to plot each element within a batch in an individual subplot, possibly sharing that subplot with traces of other individual elements of the other batched structures passed in.

Reviewed By: nikhilaravi

Differential Revision: D24235479

fbshipit-source-id: 9f669f1b186d55fe5c75552083316c0cf1387472
This commit is contained in:
Amitav Baruah 2020-10-20 17:14:38 -07:00 committed by Facebook GitHub Bot
parent abd390319c
commit 964893cdcb
2 changed files with 305 additions and 196 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds from .plotly_vis import AxisArgs, Lighting, plot_scene
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -1,14 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings import warnings
from typing import NamedTuple, Optional, Tuple from typing import Dict, List, NamedTuple, Union
import numpy as np import numpy as np
import plotly.graph_objects as go import plotly.graph_objects as go
import torch import torch
from plotly.subplots import make_subplots from plotly.subplots import make_subplots
from pytorch3d.renderer import TexturesVertex from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
class AxisArgs(NamedTuple): class AxisArgs(NamedTuple):
@ -31,216 +31,333 @@ class Lighting(NamedTuple):
vertexnormalsepsilon: float = 1e-12 vertexnormalsepsilon: float = 1e-12
def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **kwargs): def plot_scene(
""" plots: Dict[str, Dict[str, Union[Pointclouds, Meshes]]],
Takes a Meshes object and generates a plotly figure. If there is more than
one mesh in the batch and in_subplots=True, each mesh will be
visualized in an individual subplot with ncols number of subplots in the same row.
Otherwise, each mesh in the batch will be visualized as an individual trace in the
same plot. If the Meshes object has vertex colors defined as its texture, the vertex
colors will be used for generating the plotly figure. Otherwise plotly's default
colors will be used.
Args:
meshes: Meshes object to be visualized in a plotly figure.
in_subplots: if each mesh in the batch should be visualized in an individual subplot
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
ncols will be ignored.
**kwargs: Accepts lighting (a Lighting object) and lightposition for Mesh3D and any of
the args xaxis, yaxis and zaxis accept for scene. Accepts axis_args which is an
AxisArgs object that is applied to the 3 axes. Also accepts subplot_titles, which
should be a list of string titles matching the number of subplots.
Example settings for axis_args and lighting are given above.
Returns:
Plotly figure of the mesh. If there is more than one mesh in the batch,
the plotly figure will contain a series of vertically stacked subplots.
"""
meshes = meshes.detach().cpu()
subplot_titles = kwargs.get("subplot_titles", None)
fig = _gen_fig_with_subplots(len(meshes), in_subplots, ncols, subplot_titles)
for i in range(len(meshes)):
verts = meshes[i].verts_packed()
faces = meshes[i].faces_packed()
# If mesh has vertex colors defined as texture, use vertex colors
# for figure, otherwise use plotly's default colors.
verts_rgb = None
if isinstance(meshes[i].textures, TexturesVertex):
verts_rgb = meshes[i].textures.verts_features_packed()
verts_rgb.clamp_(min=0.0, max=1.0)
verts_rgb = torch.tensor(255.0) * verts_rgb
# Reposition the unused vertices to be "inside" the object
# (i.e. they won't be visible in the plot).
verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool)
verts_used[torch.unique(faces)] = True
verts_center = verts[verts_used].mean(0)
verts[~verts_used] = verts_center
trace_row = i // ncols + 1 if in_subplots else 1
trace_col = i % ncols + 1 if in_subplots else 1
fig.add_trace(
go.Mesh3d( # pyre-ignore[16]
x=verts[:, 0],
y=verts[:, 1],
z=verts[:, 2],
vertexcolor=verts_rgb,
i=faces[:, 0],
j=faces[:, 1],
k=faces[:, 2],
lighting=kwargs.get("lighting", Lighting())._asdict(),
lightposition=kwargs.get("lightposition", {}),
),
row=trace_row,
col=trace_col,
)
# Ensure update for every subplot.
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
current_layout = fig["layout"][plot_scene]
axis_args = kwargs.get("axis_args", AxisArgs())
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
verts, verts_center, current_layout, axis_args
)
# Update the axis bounds with the axis settings passed in as kwargs.
xaxis.update(**kwargs.get("xaxis", {}))
yaxis.update(**kwargs.get("yaxis", {}))
zaxis.update(**kwargs.get("zaxis", {}))
current_layout.update(
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
)
return fig
def plot_pointclouds(
pointclouds: Pointclouds,
*, *,
in_subplots: bool = False,
ncols: int = 1, ncols: int = 1,
max_points: int = 20000, pointcloud_max_points: int = 20000,
pointcloud_marker_size: int = 1,
**kwargs, **kwargs,
): ):
""" """
Takes a Pointclouds object and generates a plotly figure. If there is more than Main function to visualize Meshes and Pointclouds.
one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be Plots input Pointclouds and Meshes data into named subplots,
visualized in an individual subplot with ncols number of subplots in the same row. with named traces based on the dictionary keys.
Otherwise, each pointcloud in the batch will be visualized as an individual trace in the
same plot. If the Pointclouds object has features that are size (3) or (4) then those
rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors
will be used. Assumes that all rgb/rgba feature values are in the range [0,1].
Args: Args:
pointclouds: Pointclouds object which can contain a batch of pointclouds. plots: A dict containing subplot and trace names,
in_subplots: if each pointcloud should be visualized in an individual subplot. as well as the Meshes and Pointclouds objects to be rendered.
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise See below for examples of the format.
ncols will be ignored. ncols: the number of subplots per row
max_points: maximum number of points to plot. If the cloud has more, they are pointcloud_max_points: the maximum number of points to plot from
randomly subsampled. a pointcloud. If more are present, a random sample of size
**kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D pointcloud_max_points is used.
and any of the args xaxis, yaxis and zaxis which scene accepts. pointcloud_marker_size: the size of the points rendered by plotly
Accepts axis_args which is an AxisArgs object that is applied to the 3 axes. when plotting a pointcloud.
Also accepts subplot_titles, whichshould be a list of string titles **kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
matching the number of subplots. Example settings for axis_args and lighting are yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
given at the top of this file. which is an AxisArgs object that is applied to all 3 axes.
Example settings for axis_args and lighting are given at the
top of this file.
Returns: Example:
Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch,
the plotly figure will contain a plot with one trace per pointcloud, or with each ..code-block::python
pointcloud in a separate subplot if in_subplots is True.
mesh = ...
point_cloud = ...
fig = plot_scene({
"subplot_title": {
"mesh_trace_title": mesh,
"pointcloud_trace_title": point_cloud
}
})
fig.show()
The above example will render one subplot which has both a mesh and pointcloud.
If the Meshes or Pointclouds objects are batched, then every object in that batch
will be plotted in a single trace.
..code-block::python
mesh = ... # batch size 2
point_cloud = ... # batch size 2
fig = plot_scene({
"subplot_title": {
"mesh_trace_title": mesh,
"pointcloud_trace_title": point_cloud
}
})
fig.show()
The above example renders one subplot with 2 traces, each of which renders
both objects from their respective batched data.
Multiple subplots follow the same pattern:
..code-block::python
mesh = ... # batch size 2
point_cloud = ... # batch size 2
fig = plot_scene({
"subplot1_title": {
"mesh_trace_title": mesh[0],
"pointcloud_trace_title": point_cloud[0]
},
"subplot2_title": {
"mesh_trace_title": mesh[1],
"pointcloud_trace_title": point_cloud[1]
}
},
ncols=2) # specify the number of subplots per row
fig.show()
The above example will render two subplots, each containing a mesh
and a pointcloud. The ncols argument will render two subplots in one row
instead of having them vertically stacked because the default is one subplot
per row.
For an example of using kwargs, see below:
..code-block::python
mesh = ...
point_cloud = ...
fig = plot_scene({
"subplot_title": {
"mesh_trace_title": mesh,
"pointcloud_trace_title": point_cloud
}
},
axis_args=AxisArgs(backgroundcolor="rgb(200,230,200)")) # kwarg axis_args
fig.show()
The above example will render each axis with the input background color.
See the tutorials in pytorch3d/docs/tutorials for more examples
(namely rendered_color_points.ipynb and rendered_textured_meshes.ipynb).
""" """
pointclouds = pointclouds.detach().cpu() subplots = list(plots.keys())
subplot_titles = kwargs.get("subplot_titles", None) fig = _gen_fig_with_subplots(len(subplots), ncols, subplots)
fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles) lighting = kwargs.get("lighting", Lighting())._asdict()
for i in range(len(pointclouds)): axis_args_dict = kwargs.get("axis_args", AxisArgs())._asdict()
verts = pointclouds[i].points_packed()
features = pointclouds[i].features_packed()
indices = None # Set axis arguments to defaults defined at the top of this file
if max_points is not None and verts.shape[0] > max_points: x_settings = {**axis_args_dict}
indices = np.random.choice(verts.shape[0], max_points, replace=False) y_settings = {**axis_args_dict}
verts = verts[indices] z_settings = {**axis_args_dict}
color = None # Update the axes with any axis settings passed in as kwargs.
if features is not None: x_settings.update(**kwargs.get("xaxis", {}))
features = features[indices] y_settings.update(**kwargs.get("yaxis", {}))
if features.shape[1] == 4: # rgba z_settings.update(**kwargs.get("zaxis", {}))
template = "rgb(%d, %d, %d, %f)"
rgb = (features[:, :3] * 255).int()
color = [
template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])
]
if features.shape[1] == 3: camera = {
template = "rgb(%d, %d, %d)" "up": {
rgb = (features * 255).int() "x": 0,
color = [template % (r, g, b) for r, g, b in rgb] "y": 1,
"z": 0,
} # set the up vector to match PyTorch3D world coordinates conventions
}
trace_row = i // ncols + 1 if in_subplots else 1 for subplot_idx in range(len(subplots)):
trace_col = i % ncols + 1 if in_subplots else 1 subplot_name = subplots[subplot_idx]
fig.add_trace( traces = plots[subplot_name]
go.Scatter3d( # pyre-ignore[16] for trace_name, struct in traces.items():
x=verts[:, 0], if isinstance(struct, Meshes):
y=verts[:, 1], _add_mesh_trace(fig, struct, trace_name, subplot_idx, ncols, lighting)
z=verts[:, 2], elif isinstance(struct, Pointclouds):
marker={"color": color, "size": 1}, _add_pointcloud_trace(
mode="markers", fig,
), struct,
row=trace_row, trace_name,
col=trace_col, subplot_idx,
) ncols,
pointcloud_max_points,
pointcloud_marker_size,
)
else:
raise ValueError(
"struct {} is not a Meshes or Pointclouds object".format(struct)
)
# Ensure update for every subplot. # Ensure update for every subplot.
plot_scene = "scene" + str(i + 1) if in_subplots else "scene" plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene] current_layout = fig["layout"][plot_scene]
xaxis = current_layout["xaxis"]
yaxis = current_layout["yaxis"]
zaxis = current_layout["zaxis"]
verts_center = verts.mean(0) # Update the axes with our above default and provided settings.
xaxis.update(**x_settings)
axis_args = kwargs.get("axis_args", AxisArgs()) yaxis.update(**y_settings)
zaxis.update(**z_settings)
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
verts, verts_center, current_layout, axis_args
)
xaxis.update(**kwargs.get("xaxis", {}))
yaxis.update(**kwargs.get("yaxis", {}))
zaxis.update(**kwargs.get("zaxis", {}))
current_layout.update( current_layout.update(
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"} {
"xaxis": xaxis,
"yaxis": yaxis,
"zaxis": zaxis,
"aspectmode": "cube",
"camera": camera,
}
) )
return fig return fig
def _gen_fig_with_subplots( def _add_mesh_trace(
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list] fig: go.Figure, # pyre-ignore[11]
meshes: Meshes,
trace_name: str,
subplot_idx: int,
ncols: int,
lighting: Lighting,
): ):
"""
Adds a trace rendering a Meshes object to the passed in figure, with
a given name and in a specific subplot.
Args:
fig: plotly figure to add the trace within.
meshes: Meshes object to render. It can be batched.
trace_name: name to label the trace with.
subplot_idx: identifies the subplot, with 0 being the top left.
ncols: the number of subplots per row.
lighting: a Lighting object that specifies the Mesh3D lighting.
"""
mesh = join_meshes_as_scene(meshes)
mesh = mesh.detach().cpu()
verts = mesh.verts_packed()
faces = mesh.faces_packed()
# If mesh has vertex colors defined as texture, use vertex colors
# for figure, otherwise use plotly's default colors.
verts_rgb = None
if isinstance(mesh.textures, TexturesVertex):
verts_rgb = mesh.textures.verts_features_packed()
verts_rgb.clamp_(min=0.0, max=1.0)
verts_rgb = torch.tensor(255.0) * verts_rgb
# Reposition the unused vertices to be "inside" the object
# (i.e. they won't be visible in the plot).
verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool)
verts_used[torch.unique(faces)] = True
verts_center = verts[verts_used].mean(0)
verts[~verts_used] = verts_center
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
fig.add_trace(
go.Mesh3d( # pyre-ignore[16]
x=verts[:, 0],
y=verts[:, 1],
z=verts[:, 2],
vertexcolor=verts_rgb,
i=faces[:, 0],
j=faces[:, 1],
k=faces[:, 2],
lighting=lighting,
name=trace_name,
),
row=row,
col=col,
)
# Access the current subplot's scene configuration
plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene]
# update the bounds of the axes for the current trace
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
_update_axes_bounds(verts_center, max_expand, current_layout)
def _add_pointcloud_trace(
fig: go.Figure,
pointclouds: Pointclouds,
trace_name: str,
subplot_idx: int,
ncols: int,
max_points_per_pointcloud: int,
marker_size: int,
):
"""
Adds a trace rendering a Pointclouds object to the passed in figure, with
a given name and in a specific subplot.
Args:
fig: plotly figure to add the trace within.
pointclouds: Pointclouds object to render. It can be batched.
trace_name: name to label the trace with.
subplot_idx: identifies the subplot, with 0 being the top left.
ncols: the number of sublpots per row.
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
marker_size: the size of the rendered points
"""
pointclouds = pointclouds.detach().cpu()
verts = pointclouds.points_packed()
features = pointclouds.features_packed()
total_points_count = max_points_per_pointcloud * len(pointclouds)
indices = None
if verts.shape[0] > total_points_count:
indices = np.random.choice(verts.shape[0], total_points_count, replace=False)
verts = verts[indices]
color = None
if features is not None:
features = features[indices]
if features.shape[1] == 4: # rgba
template = "rgb(%d, %d, %d, %f)"
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()
color = [template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])]
if features.shape[1] == 3:
template = "rgb(%d, %d, %d)"
rgb = (features.clamp(0.0, 1.0) * 255).int()
color = [template % (r, g, b) for r, g, b in rgb]
row = subplot_idx // ncols + 1
col = subplot_idx % ncols + 1
fig.add_trace(
go.Scatter3d( # pyre-ignore[16]
x=verts[:, 0],
y=verts[:, 1],
z=verts[:, 2],
marker={"color": color, "size": marker_size},
mode="markers",
name=trace_name,
),
row=row,
col=col,
)
# Access the current subplot's scene configuration
plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene]
# update the bounds of the axes for the current trace
verts_center = verts.mean(0)
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
_update_axes_bounds(verts_center, max_expand, current_layout)
def _gen_fig_with_subplots(batch_size: int, ncols: int, subplot_titles: List[str]):
""" """
Takes in the number of objects to be plotted and generate a plotly figure Takes in the number of objects to be plotted and generate a plotly figure
with the appropriate number and orientation of subplots with the appropriate number and orientation of titled subplots.
Args: Args:
batch_size: the number of elements in the batch of objects to be visualized. batch_size: the number of elements in the batch of objects to be visualized.
in_subplots: if each object should be visualized in an individual subplot. ncols: number of subplots in the same row.
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise subplot_titles: titles for the subplot(s). list of strings of length batch_size.
ncols will be ignored.
subplot_titles: titles for the subplot(s). list of strings of length batch_size
if in_subplots is True, otherwise length 1.
Returns: Returns:
Plotly figure with one plot if in_subplots is false. Otherwise, returns a plotly figure Plotly figure with ncols subplots per row, and batch_size subplots.
with ncols subplots per row.
""" """
if batch_size % ncols != 0: if batch_size % ncols != 0:
msg = "ncols is invalid for the given mesh batch size." msg = "ncols is invalid for the given mesh batch size."
warnings.warn(msg) warnings.warn(msg)
fig_rows = batch_size // ncols if in_subplots else 1 fig_rows = batch_size // ncols
fig_cols = ncols if in_subplots else 1 fig_cols = ncols
fig_type = [{"type": "scene"}] fig_type = [{"type": "scene"}]
specs = [fig_type * fig_cols] * fig_rows specs = [fig_type * fig_cols] * fig_rows
# subplot_titles must have one title per subplot # subplot_titles must have one title per subplot
if subplot_titles is not None and len(subplot_titles) != fig_cols * fig_rows:
subplot_titles = None
fig = make_subplots( fig = make_subplots(
rows=fig_rows, rows=fig_rows,
cols=fig_cols, cols=fig_cols,
@ -251,27 +368,19 @@ def _gen_fig_with_subplots(
return fig return fig
def _gen_updated_axis_bounds( def _update_axes_bounds(
verts: torch.Tensor,
verts_center: torch.Tensor, verts_center: torch.Tensor,
max_expand: float,
current_layout: go.Scene, # pyre-ignore[11] current_layout: go.Scene, # pyre-ignore[11]
axis_args: AxisArgs, ):
) -> Tuple[dict, dict, dict]:
""" """
Takes in the vertices, center point of the vertices, and the current plotly figure and Takes in the vertices' center point and max spread, and the current plotly figure
outputs axes with bounds that capture all points in the current subplot. layout and updates the layout to have bounds that include all traces for that subplot.
Args: Args:
verts: tensor of size (N, 3) representing N points with xyz coordinates. verts_center: tensor of size (3) corresponding to a trace's vertices' center point.
verts_center: tensor of size (3) corresponding to the center point of verts. max_expand: the maximum spread in any dimension of the trace's vertices.
current_layout: the current plotly figure layout scene corresponding to verts' trace. current_layout: the plotly figure layout scene corresponding to the referenced trace.
axis_args: an AxisArgs object with default and/or user-set values for plotly's axes.
Returns:
a 3 item tuple of xaxis, yaxis, and zaxis, which are dictionaries with axis arguments
for plotly including a range key with value the minimum and maximum value for that axis.
""" """
# Get ranges of vertices.
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
verts_min = verts_center - max_expand verts_min = verts_center - max_expand
verts_max = verts_center + max_expand verts_max = verts_center + max_expand
bounds = torch.t(torch.stack((verts_min, verts_max))) bounds = torch.t(torch.stack((verts_min, verts_max)))
@ -292,8 +401,8 @@ def _gen_updated_axis_bounds(
if old_zrange is not None: if old_zrange is not None:
z_range[0] = min(z_range[0], old_zrange[0]) z_range[0] = min(z_range[0], old_zrange[0])
z_range[1] = max(z_range[1], old_zrange[1]) z_range[1] = max(z_range[1], old_zrange[1])
axis_args_dict = axis_args._asdict()
xaxis = {"range": x_range, **axis_args_dict} xaxis = {"range": x_range}
yaxis = {"range": y_range, **axis_args_dict} yaxis = {"range": y_range}
zaxis = {"range": z_range, **axis_args_dict} zaxis = {"range": z_range}
return xaxis, yaxis, zaxis current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})