mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Add wrapper function to plot batches
Summary: - adds plot_batch_individually - for each batched object, plots each object in its own subplot with other same-indexed elements of the other batched objects provided as input Reviewed By: nikhilaravi Differential Revision: D24258389 fbshipit-source-id: a80128e6e7a03a44c257b0598569159afadb2d39
This commit is contained in:
parent
964893cdcb
commit
bf7aca320a
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .plotly_vis import AxisArgs, Lighting, plot_scene
|
||||
from .plotly_vis import AxisArgs, Lighting, plot_batch_individually, plot_scene
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import Dict, List, NamedTuple, Union
|
||||
from typing import Dict, List, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
@ -202,6 +202,145 @@ def plot_scene(
|
||||
return fig
|
||||
|
||||
|
||||
def plot_batch_individually(
|
||||
batched_structs: Union[List[Union[Meshes, Pointclouds]], Meshes, Pointclouds],
|
||||
*,
|
||||
ncols: int = 1,
|
||||
extend_struct: bool = True,
|
||||
subplot_titles: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
This is a higher level plotting function than plot_scene, for plotting
|
||||
Meshes and Pointclouds in simple cases. The simplest use is to plot a
|
||||
single Meshes or Pointclouds object, where you just pass it in as a
|
||||
one element list. This will plot each batch element in a separate subplot.
|
||||
|
||||
More generally, you can supply multiple Meshes or Pointclouds
|
||||
having the same batch size `n`. In this case, there will be `n` subplots,
|
||||
each depicting the corresponding batch element of all the inputs.
|
||||
|
||||
In addition, you can include Meshes and Pointclouds of size 1 in
|
||||
the input. These will either be rendered in the first subplot
|
||||
(if extend_struct is False), or in every subplot.
|
||||
|
||||
Args:
|
||||
batched_structs: a list of Meshes and/or Pointclouds to be rendered.
|
||||
Each structure's corresponding batch element will be plotted in
|
||||
a single subplot, resulting in n subplots for a batch of size n.
|
||||
Every struct should either have the same batch size or be of batch size 1.
|
||||
See extend_struct and the description above for how batch size 1 structs
|
||||
are handled. Also accepts a single Meshes or Pointclouds object, which will have
|
||||
each individual element plotted in its own subplot.
|
||||
ncols: the number of subplots per row
|
||||
extend_struct: if True, indicates that structs of batch size 1
|
||||
should be plotted in every subplot.
|
||||
subplot_titles: strings to name each subplot
|
||||
**kwargs: keyword arguments which are passed to plot_scene.
|
||||
See plot_scene documentation for details.
|
||||
|
||||
Example:
|
||||
|
||||
..code-block::python
|
||||
|
||||
mesh = ... # mesh of batch size 2
|
||||
point_cloud = ... # point_cloud of batch size 2
|
||||
fig = plot_batch_individually([mesh, point_cloud], subplot_titles=["plot1", "plot2"])
|
||||
fig.show()
|
||||
|
||||
# this is equivalent to the below figure
|
||||
fig = plot_scene({
|
||||
"plot1": {
|
||||
"trace1-1": mesh[0],
|
||||
"trace1-2": point_cloud[0]
|
||||
},
|
||||
"plot2":{
|
||||
"trace2-1": mesh[1],
|
||||
"trace2-2": point_cloud[1]
|
||||
}
|
||||
})
|
||||
fig.show()
|
||||
|
||||
The above example will render two subplots which each have both a mesh and pointcloud.
|
||||
For more examples look at the pytorch3d tutorials at `pytorch3d/docs/tutorials`,
|
||||
in particular the files rendered_color_points.ipynb and rendered_textured_meshes.ipynb.
|
||||
"""
|
||||
|
||||
# check that every batch is the same size or is size 1
|
||||
if len(batched_structs) == 0:
|
||||
msg = "No structs to plot"
|
||||
warnings.warn(msg)
|
||||
return
|
||||
max_size = 0
|
||||
if isinstance(batched_structs, list):
|
||||
max_size = max(len(s) for s in batched_structs)
|
||||
for struct in batched_structs:
|
||||
if len(struct) not in (1, max_size):
|
||||
msg = "invalid batch size {} provided: {}".format(len(struct), struct)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
max_size = len(batched_structs)
|
||||
|
||||
if max_size == 0:
|
||||
msg = "No data is provided with at least one element"
|
||||
raise ValueError(msg)
|
||||
|
||||
if subplot_titles:
|
||||
if len(subplot_titles) != max_size:
|
||||
msg = "invalid number of subplot titles"
|
||||
raise ValueError(msg)
|
||||
|
||||
scene_dictionary = {}
|
||||
# construct the scene dictionary
|
||||
for scene_num in range(max_size):
|
||||
subplot_title = (
|
||||
subplot_titles[scene_num]
|
||||
if subplot_titles
|
||||
else "subplot " + str(scene_num + 1)
|
||||
)
|
||||
scene_dictionary[subplot_title] = {}
|
||||
|
||||
if isinstance(batched_structs, list):
|
||||
for i, batched_struct in enumerate(batched_structs):
|
||||
# check for whether this struct needs to be extended
|
||||
if i >= len(batched_struct) and not extend_struct:
|
||||
continue
|
||||
_add_struct_from_batch(
|
||||
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
|
||||
)
|
||||
else: # batched_structs is a single struct
|
||||
_add_struct_from_batch(
|
||||
batched_structs, scene_num, subplot_title, scene_dictionary
|
||||
)
|
||||
|
||||
return plot_scene(scene_dictionary, ncols=ncols, **kwargs)
|
||||
|
||||
|
||||
def _add_struct_from_batch(
|
||||
batched_struct: Union[Meshes, Pointclouds],
|
||||
scene_num: int,
|
||||
subplot_title: str,
|
||||
scene_dictionary: Dict[str, Dict[str, Union[Meshes, Pointclouds]]],
|
||||
trace_idx: int = 1,
|
||||
):
|
||||
"""
|
||||
Adds the struct corresponding to the given scene_num index to
|
||||
a provided scene_dictionary to be passed in to plot_scene
|
||||
|
||||
Args:
|
||||
batched_struct: the batched data structure to add to the dict
|
||||
scene_num: the subplot from plot_batch_individually which this struct
|
||||
should be added to
|
||||
subplot_title: the title of the subplot
|
||||
scene_dictionary: the dictionary to add the indexed struct to
|
||||
trace_idx: the trace number, starting at 1 for this struct's trace
|
||||
"""
|
||||
struct_idx = min(scene_num, len(batched_struct) - 1)
|
||||
struct = batched_struct[struct_idx]
|
||||
trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
|
||||
scene_dictionary[subplot_title][trace_name] = struct
|
||||
|
||||
|
||||
def _add_mesh_trace(
|
||||
fig: go.Figure, # pyre-ignore[11]
|
||||
meshes: Meshes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user