Refactor plot_meshes and plot_pointclouds to one generalizable API, plot_scene

Summary: Defines a function plot_scene that takes in a dictionary defining subplot and trace layouts for Mesh/Pointcloud objects and plots them. Also supports other plotly axis arguments and mesh lighting. Plot_batch_individually is a wrapper function that takes in one or multiple batched Meshes/Pointclouds and uses plot_scene to plot each element within a batch in an individual subplot, possibly sharing that subplot with traces of other individual elements of the other batched structures passed in.

Reviewed By: nikhilaravi

Differential Revision: D24235479

fbshipit-source-id: 9f669f1b186d55fe5c75552083316c0cf1387472
This commit is contained in:
Amitav Baruah 2020-10-20 17:14:38 -07:00 committed by Facebook GitHub Bot
parent abd390319c
commit 964893cdcb
2 changed files with 305 additions and 196 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds
from .plotly_vis import AxisArgs, Lighting, plot_scene
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -1,14 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import NamedTuple, Optional, Tuple
from typing import Dict, List, NamedTuple, Union
import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
class AxisArgs(NamedTuple):
@ -31,42 +31,207 @@ class Lighting(NamedTuple):
vertexnormalsepsilon: float = 1e-12
def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **kwargs):
def plot_scene(
plots: Dict[str, Dict[str, Union[Pointclouds, Meshes]]],
*,
ncols: int = 1,
pointcloud_max_points: int = 20000,
pointcloud_marker_size: int = 1,
**kwargs,
):
"""
Takes a Meshes object and generates a plotly figure. If there is more than
one mesh in the batch and in_subplots=True, each mesh will be
visualized in an individual subplot with ncols number of subplots in the same row.
Otherwise, each mesh in the batch will be visualized as an individual trace in the
same plot. If the Meshes object has vertex colors defined as its texture, the vertex
colors will be used for generating the plotly figure. Otherwise plotly's default
colors will be used.
Main function to visualize Meshes and Pointclouds.
Plots input Pointclouds and Meshes data into named subplots,
with named traces based on the dictionary keys.
Args:
meshes: Meshes object to be visualized in a plotly figure.
in_subplots: if each mesh in the batch should be visualized in an individual subplot
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
ncols will be ignored.
**kwargs: Accepts lighting (a Lighting object) and lightposition for Mesh3D and any of
the args xaxis, yaxis and zaxis accept for scene. Accepts axis_args which is an
AxisArgs object that is applied to the 3 axes. Also accepts subplot_titles, which
should be a list of string titles matching the number of subplots.
Example settings for axis_args and lighting are given above.
plots: A dict containing subplot and trace names,
as well as the Meshes and Pointclouds objects to be rendered.
See below for examples of the format.
ncols: the number of subplots per row
pointcloud_max_points: the maximum number of points to plot from
a pointcloud. If more are present, a random sample of size
pointcloud_max_points is used.
pointcloud_marker_size: the size of the points rendered by plotly
when plotting a pointcloud.
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
which is an AxisArgs object that is applied to all 3 axes.
Example settings for axis_args and lighting are given at the
top of this file.
Returns:
Plotly figure of the mesh. If there is more than one mesh in the batch,
the plotly figure will contain a series of vertically stacked subplots.
Example:
..code-block::python
mesh = ...
point_cloud = ...
fig = plot_scene({
"subplot_title": {
"mesh_trace_title": mesh,
"pointcloud_trace_title": point_cloud
}
})
fig.show()
The above example will render one subplot which has both a mesh and pointcloud.
If the Meshes or Pointclouds objects are batched, then every object in that batch
will be plotted in a single trace.
..code-block::python
mesh = ... # batch size 2
point_cloud = ... # batch size 2
fig = plot_scene({
"subplot_title": {
"mesh_trace_title": mesh,
"pointcloud_trace_title": point_cloud
}
})
fig.show()
The above example renders one subplot with 2 traces, each of which renders
both objects from their respective batched data.
Multiple subplots follow the same pattern:
..code-block::python
mesh = ... # batch size 2
point_cloud = ... # batch size 2
fig = plot_scene({
"subplot1_title": {
"mesh_trace_title": mesh[0],
"pointcloud_trace_title": point_cloud[0]
},
"subplot2_title": {
"mesh_trace_title": mesh[1],
"pointcloud_trace_title": point_cloud[1]
}
},
ncols=2) # specify the number of subplots per row
fig.show()
The above example will render two subplots, each containing a mesh
and a pointcloud. The ncols argument will render two subplots in one row
instead of having them vertically stacked because the default is one subplot
per row.
For an example of using kwargs, see below:
..code-block::python
mesh = ...
point_cloud = ...
fig = plot_scene({
"subplot_title": {
"mesh_trace_title": mesh,
"pointcloud_trace_title": point_cloud
}
},
axis_args=AxisArgs(backgroundcolor="rgb(200,230,200)")) # kwarg axis_args
fig.show()
The above example will render each axis with the input background color.
See the tutorials in pytorch3d/docs/tutorials for more examples
(namely rendered_color_points.ipynb and rendered_textured_meshes.ipynb).
"""
meshes = meshes.detach().cpu()
subplot_titles = kwargs.get("subplot_titles", None)
fig = _gen_fig_with_subplots(len(meshes), in_subplots, ncols, subplot_titles)
for i in range(len(meshes)):
verts = meshes[i].verts_packed()
faces = meshes[i].faces_packed()
subplots = list(plots.keys())
fig = _gen_fig_with_subplots(len(subplots), ncols, subplots)
lighting = kwargs.get("lighting", Lighting())._asdict()
axis_args_dict = kwargs.get("axis_args", AxisArgs())._asdict()
# Set axis arguments to defaults defined at the top of this file
x_settings = {**axis_args_dict}
y_settings = {**axis_args_dict}
z_settings = {**axis_args_dict}
# Update the axes with any axis settings passed in as kwargs.
x_settings.update(**kwargs.get("xaxis", {}))
y_settings.update(**kwargs.get("yaxis", {}))
z_settings.update(**kwargs.get("zaxis", {}))
camera = {
"up": {
"x": 0,
"y": 1,
"z": 0,
} # set the up vector to match PyTorch3D world coordinates conventions
}
for subplot_idx in range(len(subplots)):
subplot_name = subplots[subplot_idx]
traces = plots[subplot_name]
for trace_name, struct in traces.items():
if isinstance(struct, Meshes):
_add_mesh_trace(fig, struct, trace_name, subplot_idx, ncols, lighting)
elif isinstance(struct, Pointclouds):
_add_pointcloud_trace(
fig,
struct,
trace_name,
subplot_idx,
ncols,
pointcloud_max_points,
pointcloud_marker_size,
)
else:
raise ValueError(
"struct {} is not a Meshes or Pointclouds object".format(struct)
)
# Ensure update for every subplot.
plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene]
xaxis = current_layout["xaxis"]
yaxis = current_layout["yaxis"]
zaxis = current_layout["zaxis"]
# Update the axes with our above default and provided settings.
xaxis.update(**x_settings)
yaxis.update(**y_settings)
zaxis.update(**z_settings)
current_layout.update(
{
"xaxis": xaxis,
"yaxis": yaxis,
"zaxis": zaxis,
"aspectmode": "cube",
"camera": camera,
}
)
return fig
def _add_mesh_trace(
fig: go.Figure, # pyre-ignore[11]
meshes: Meshes,
trace_name: str,
subplot_idx: int,
ncols: int,
lighting: Lighting,
):
"""
Adds a trace rendering a Meshes object to the passed in figure, with
a given name and in a specific subplot.
Args:
fig: plotly figure to add the trace within.
meshes: Meshes object to render. It can be batched.
trace_name: name to label the trace with.
subplot_idx: identifies the subplot, with 0 being the top left.
ncols: the number of subplots per row.
lighting: a Lighting object that specifies the Mesh3D lighting.
"""
mesh = join_meshes_as_scene(meshes)
mesh = mesh.detach().cpu()
verts = mesh.verts_packed()
faces = mesh.faces_packed()
# If mesh has vertex colors defined as texture, use vertex colors
# for figure, otherwise use plotly's default colors.
verts_rgb = None
if isinstance(meshes[i].textures, TexturesVertex):
verts_rgb = meshes[i].textures.verts_features_packed()
if isinstance(mesh.textures, TexturesVertex):
verts_rgb = mesh.textures.verts_features_packed()
verts_rgb.clamp_(min=0.0, max=1.0)
verts_rgb = torch.tensor(255.0) * verts_rgb
@ -77,8 +242,7 @@ def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **
verts_center = verts[verts_used].mean(0)
verts[~verts_used] = verts_center
trace_row = i // ncols + 1 if in_subplots else 1
trace_col = i % ncols + 1 if in_subplots else 1
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
fig.add_trace(
go.Mesh3d( # pyre-ignore[16]
x=verts[:, 0],
@ -88,78 +252,52 @@ def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **
i=faces[:, 0],
j=faces[:, 1],
k=faces[:, 2],
lighting=kwargs.get("lighting", Lighting())._asdict(),
lightposition=kwargs.get("lightposition", {}),
lighting=lighting,
name=trace_name,
),
row=trace_row,
col=trace_col,
row=row,
col=col,
)
# Ensure update for every subplot.
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
# Access the current subplot's scene configuration
plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene]
axis_args = kwargs.get("axis_args", AxisArgs())
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
verts, verts_center, current_layout, axis_args
)
# Update the axis bounds with the axis settings passed in as kwargs.
xaxis.update(**kwargs.get("xaxis", {}))
yaxis.update(**kwargs.get("yaxis", {}))
zaxis.update(**kwargs.get("zaxis", {}))
current_layout.update(
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
)
return fig
# update the bounds of the axes for the current trace
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
_update_axes_bounds(verts_center, max_expand, current_layout)
def plot_pointclouds(
def _add_pointcloud_trace(
fig: go.Figure,
pointclouds: Pointclouds,
*,
in_subplots: bool = False,
ncols: int = 1,
max_points: int = 20000,
**kwargs,
trace_name: str,
subplot_idx: int,
ncols: int,
max_points_per_pointcloud: int,
marker_size: int,
):
"""
Takes a Pointclouds object and generates a plotly figure. If there is more than
one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be
visualized in an individual subplot with ncols number of subplots in the same row.
Otherwise, each pointcloud in the batch will be visualized as an individual trace in the
same plot. If the Pointclouds object has features that are size (3) or (4) then those
rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors
will be used. Assumes that all rgb/rgba feature values are in the range [0,1].
Adds a trace rendering a Pointclouds object to the passed in figure, with
a given name and in a specific subplot.
Args:
pointclouds: Pointclouds object which can contain a batch of pointclouds.
in_subplots: if each pointcloud should be visualized in an individual subplot.
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
ncols will be ignored.
max_points: maximum number of points to plot. If the cloud has more, they are
randomly subsampled.
**kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D
and any of the args xaxis, yaxis and zaxis which scene accepts.
Accepts axis_args which is an AxisArgs object that is applied to the 3 axes.
Also accepts subplot_titles, whichshould be a list of string titles
matching the number of subplots. Example settings for axis_args and lighting are
given at the top of this file.
Returns:
Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch,
the plotly figure will contain a plot with one trace per pointcloud, or with each
pointcloud in a separate subplot if in_subplots is True.
fig: plotly figure to add the trace within.
pointclouds: Pointclouds object to render. It can be batched.
trace_name: name to label the trace with.
subplot_idx: identifies the subplot, with 0 being the top left.
ncols: the number of sublpots per row.
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
marker_size: the size of the rendered points
"""
pointclouds = pointclouds.detach().cpu()
subplot_titles = kwargs.get("subplot_titles", None)
fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles)
for i in range(len(pointclouds)):
verts = pointclouds[i].points_packed()
features = pointclouds[i].features_packed()
verts = pointclouds.points_packed()
features = pointclouds.features_packed()
total_points_count = max_points_per_pointcloud * len(pointclouds)
indices = None
if max_points is not None and verts.shape[0] > max_points:
indices = np.random.choice(verts.shape[0], max_points, replace=False)
if verts.shape[0] > total_points_count:
indices = np.random.choice(verts.shape[0], total_points_count, replace=False)
verts = verts[indices]
color = None
@ -167,80 +305,59 @@ def plot_pointclouds(
features = features[indices]
if features.shape[1] == 4: # rgba
template = "rgb(%d, %d, %d, %f)"
rgb = (features[:, :3] * 255).int()
color = [
template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])
]
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()
color = [template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])]
if features.shape[1] == 3:
template = "rgb(%d, %d, %d)"
rgb = (features * 255).int()
rgb = (features.clamp(0.0, 1.0) * 255).int()
color = [template % (r, g, b) for r, g, b in rgb]
trace_row = i // ncols + 1 if in_subplots else 1
trace_col = i % ncols + 1 if in_subplots else 1
row = subplot_idx // ncols + 1
col = subplot_idx % ncols + 1
fig.add_trace(
go.Scatter3d( # pyre-ignore[16]
x=verts[:, 0],
y=verts[:, 1],
z=verts[:, 2],
marker={"color": color, "size": 1},
marker={"color": color, "size": marker_size},
mode="markers",
name=trace_name,
),
row=trace_row,
col=trace_col,
row=row,
col=col,
)
# Ensure update for every subplot.
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
# Access the current subplot's scene configuration
plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene]
# update the bounds of the axes for the current trace
verts_center = verts.mean(0)
axis_args = kwargs.get("axis_args", AxisArgs())
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
verts, verts_center, current_layout, axis_args
)
xaxis.update(**kwargs.get("xaxis", {}))
yaxis.update(**kwargs.get("yaxis", {}))
zaxis.update(**kwargs.get("zaxis", {}))
current_layout.update(
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
)
return fig
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
_update_axes_bounds(verts_center, max_expand, current_layout)
def _gen_fig_with_subplots(
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list]
):
def _gen_fig_with_subplots(batch_size: int, ncols: int, subplot_titles: List[str]):
"""
Takes in the number of objects to be plotted and generate a plotly figure
with the appropriate number and orientation of subplots
with the appropriate number and orientation of titled subplots.
Args:
batch_size: the number of elements in the batch of objects to be visualized.
in_subplots: if each object should be visualized in an individual subplot.
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
ncols will be ignored.
subplot_titles: titles for the subplot(s). list of strings of length batch_size
if in_subplots is True, otherwise length 1.
ncols: number of subplots in the same row.
subplot_titles: titles for the subplot(s). list of strings of length batch_size.
Returns:
Plotly figure with one plot if in_subplots is false. Otherwise, returns a plotly figure
with ncols subplots per row.
Plotly figure with ncols subplots per row, and batch_size subplots.
"""
if batch_size % ncols != 0:
msg = "ncols is invalid for the given mesh batch size."
warnings.warn(msg)
fig_rows = batch_size // ncols if in_subplots else 1
fig_cols = ncols if in_subplots else 1
fig_rows = batch_size // ncols
fig_cols = ncols
fig_type = [{"type": "scene"}]
specs = [fig_type * fig_cols] * fig_rows
# subplot_titles must have one title per subplot
if subplot_titles is not None and len(subplot_titles) != fig_cols * fig_rows:
subplot_titles = None
fig = make_subplots(
rows=fig_rows,
cols=fig_cols,
@ -251,27 +368,19 @@ def _gen_fig_with_subplots(
return fig
def _gen_updated_axis_bounds(
verts: torch.Tensor,
def _update_axes_bounds(
verts_center: torch.Tensor,
max_expand: float,
current_layout: go.Scene, # pyre-ignore[11]
axis_args: AxisArgs,
) -> Tuple[dict, dict, dict]:
):
"""
Takes in the vertices, center point of the vertices, and the current plotly figure and
outputs axes with bounds that capture all points in the current subplot.
Takes in the vertices' center point and max spread, and the current plotly figure
layout and updates the layout to have bounds that include all traces for that subplot.
Args:
verts: tensor of size (N, 3) representing N points with xyz coordinates.
verts_center: tensor of size (3) corresponding to the center point of verts.
current_layout: the current plotly figure layout scene corresponding to verts' trace.
axis_args: an AxisArgs object with default and/or user-set values for plotly's axes.
Returns:
a 3 item tuple of xaxis, yaxis, and zaxis, which are dictionaries with axis arguments
for plotly including a range key with value the minimum and maximum value for that axis.
verts_center: tensor of size (3) corresponding to a trace's vertices' center point.
max_expand: the maximum spread in any dimension of the trace's vertices.
current_layout: the plotly figure layout scene corresponding to the referenced trace.
"""
# Get ranges of vertices.
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
verts_min = verts_center - max_expand
verts_max = verts_center + max_expand
bounds = torch.t(torch.stack((verts_min, verts_max)))
@ -292,8 +401,8 @@ def _gen_updated_axis_bounds(
if old_zrange is not None:
z_range[0] = min(z_range[0], old_zrange[0])
z_range[1] = max(z_range[1], old_zrange[1])
axis_args_dict = axis_args._asdict()
xaxis = {"range": x_range, **axis_args_dict}
yaxis = {"range": y_range, **axis_args_dict}
zaxis = {"range": z_range, **axis_args_dict}
return xaxis, yaxis, zaxis
xaxis = {"range": x_range}
yaxis = {"range": y_range}
zaxis = {"range": z_range}
current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})