Add wrapper function to plot batches

Summary:
- adds plot_batch_individually
- for each batched object, plots each object in its own subplot with other same-indexed elements of the other batched objects provided as input

Reviewed By: nikhilaravi

Differential Revision: D24258389

fbshipit-source-id: a80128e6e7a03a44c257b0598569159afadb2d39
This commit is contained in:
Amitav Baruah 2020-10-20 17:14:38 -07:00 committed by Facebook GitHub Bot
parent 964893cdcb
commit bf7aca320a
2 changed files with 141 additions and 2 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .plotly_vis import AxisArgs, Lighting, plot_scene
from .plotly_vis import AxisArgs, Lighting, plot_batch_individually, plot_scene
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import Dict, List, NamedTuple, Union
from typing import Dict, List, NamedTuple, Optional, Union
import numpy as np
import plotly.graph_objects as go
@ -202,6 +202,145 @@ def plot_scene(
return fig
def plot_batch_individually(
batched_structs: Union[List[Union[Meshes, Pointclouds]], Meshes, Pointclouds],
*,
ncols: int = 1,
extend_struct: bool = True,
subplot_titles: Optional[List[str]] = None,
**kwargs,
):
"""
This is a higher level plotting function than plot_scene, for plotting
Meshes and Pointclouds in simple cases. The simplest use is to plot a
single Meshes or Pointclouds object, where you just pass it in as a
one element list. This will plot each batch element in a separate subplot.
More generally, you can supply multiple Meshes or Pointclouds
having the same batch size `n`. In this case, there will be `n` subplots,
each depicting the corresponding batch element of all the inputs.
In addition, you can include Meshes and Pointclouds of size 1 in
the input. These will either be rendered in the first subplot
(if extend_struct is False), or in every subplot.
Args:
batched_structs: a list of Meshes and/or Pointclouds to be rendered.
Each structure's corresponding batch element will be plotted in
a single subplot, resulting in n subplots for a batch of size n.
Every struct should either have the same batch size or be of batch size 1.
See extend_struct and the description above for how batch size 1 structs
are handled. Also accepts a single Meshes or Pointclouds object, which will have
each individual element plotted in its own subplot.
ncols: the number of subplots per row
extend_struct: if True, indicates that structs of batch size 1
should be plotted in every subplot.
subplot_titles: strings to name each subplot
**kwargs: keyword arguments which are passed to plot_scene.
See plot_scene documentation for details.
Example:
..code-block::python
mesh = ... # mesh of batch size 2
point_cloud = ... # point_cloud of batch size 2
fig = plot_batch_individually([mesh, point_cloud], subplot_titles=["plot1", "plot2"])
fig.show()
# this is equivalent to the below figure
fig = plot_scene({
"plot1": {
"trace1-1": mesh[0],
"trace1-2": point_cloud[0]
},
"plot2":{
"trace2-1": mesh[1],
"trace2-2": point_cloud[1]
}
})
fig.show()
The above example will render two subplots which each have both a mesh and pointcloud.
For more examples look at the pytorch3d tutorials at `pytorch3d/docs/tutorials`,
in particular the files rendered_color_points.ipynb and rendered_textured_meshes.ipynb.
"""
# check that every batch is the same size or is size 1
if len(batched_structs) == 0:
msg = "No structs to plot"
warnings.warn(msg)
return
max_size = 0
if isinstance(batched_structs, list):
max_size = max(len(s) for s in batched_structs)
for struct in batched_structs:
if len(struct) not in (1, max_size):
msg = "invalid batch size {} provided: {}".format(len(struct), struct)
raise ValueError(msg)
else:
max_size = len(batched_structs)
if max_size == 0:
msg = "No data is provided with at least one element"
raise ValueError(msg)
if subplot_titles:
if len(subplot_titles) != max_size:
msg = "invalid number of subplot titles"
raise ValueError(msg)
scene_dictionary = {}
# construct the scene dictionary
for scene_num in range(max_size):
subplot_title = (
subplot_titles[scene_num]
if subplot_titles
else "subplot " + str(scene_num + 1)
)
scene_dictionary[subplot_title] = {}
if isinstance(batched_structs, list):
for i, batched_struct in enumerate(batched_structs):
# check for whether this struct needs to be extended
if i >= len(batched_struct) and not extend_struct:
continue
_add_struct_from_batch(
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
)
else: # batched_structs is a single struct
_add_struct_from_batch(
batched_structs, scene_num, subplot_title, scene_dictionary
)
return plot_scene(scene_dictionary, ncols=ncols, **kwargs)
def _add_struct_from_batch(
batched_struct: Union[Meshes, Pointclouds],
scene_num: int,
subplot_title: str,
scene_dictionary: Dict[str, Dict[str, Union[Meshes, Pointclouds]]],
trace_idx: int = 1,
):
"""
Adds the struct corresponding to the given scene_num index to
a provided scene_dictionary to be passed in to plot_scene
Args:
batched_struct: the batched data structure to add to the dict
scene_num: the subplot from plot_batch_individually which this struct
should be added to
subplot_title: the title of the subplot
scene_dictionary: the dictionary to add the indexed struct to
trace_idx: the trace number, starting at 1 for this struct's trace
"""
struct_idx = min(scene_num, len(batched_struct) - 1)
struct = batched_struct[struct_idx]
trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
scene_dictionary[subplot_title][trace_name] = struct
def _add_mesh_trace(
fig: go.Figure, # pyre-ignore[11]
meshes: Meshes,