mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Refactor plot_meshes and plot_pointclouds to one generalizable API, plot_scene
Summary: Defines a function plot_scene that takes in a dictionary defining subplot and trace layouts for Mesh/Pointcloud objects and plots them. Also supports other plotly axis arguments and mesh lighting. Plot_batch_individually is a wrapper function that takes in one or multiple batched Meshes/Pointclouds and uses plot_scene to plot each element within a batch in an individual subplot, possibly sharing that subplot with traces of other individual elements of the other batched structures passed in. Reviewed By: nikhilaravi Differential Revision: D24235479 fbshipit-source-id: 9f669f1b186d55fe5c75552083316c0cf1387472
This commit is contained in:
parent
abd390319c
commit
964893cdcb
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds
|
||||
from .plotly_vis import AxisArgs, Lighting, plot_scene
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
@ -1,14 +1,14 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
from typing import Dict, List, NamedTuple, Union
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import torch
|
||||
from plotly.subplots import make_subplots
|
||||
from pytorch3d.renderer import TexturesVertex
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
|
||||
|
||||
|
||||
class AxisArgs(NamedTuple):
|
||||
@ -31,216 +31,333 @@ class Lighting(NamedTuple):
|
||||
vertexnormalsepsilon: float = 1e-12
|
||||
|
||||
|
||||
def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **kwargs):
|
||||
"""
|
||||
Takes a Meshes object and generates a plotly figure. If there is more than
|
||||
one mesh in the batch and in_subplots=True, each mesh will be
|
||||
visualized in an individual subplot with ncols number of subplots in the same row.
|
||||
Otherwise, each mesh in the batch will be visualized as an individual trace in the
|
||||
same plot. If the Meshes object has vertex colors defined as its texture, the vertex
|
||||
colors will be used for generating the plotly figure. Otherwise plotly's default
|
||||
colors will be used.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object to be visualized in a plotly figure.
|
||||
in_subplots: if each mesh in the batch should be visualized in an individual subplot
|
||||
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
||||
ncols will be ignored.
|
||||
**kwargs: Accepts lighting (a Lighting object) and lightposition for Mesh3D and any of
|
||||
the args xaxis, yaxis and zaxis accept for scene. Accepts axis_args which is an
|
||||
AxisArgs object that is applied to the 3 axes. Also accepts subplot_titles, which
|
||||
should be a list of string titles matching the number of subplots.
|
||||
Example settings for axis_args and lighting are given above.
|
||||
|
||||
Returns:
|
||||
Plotly figure of the mesh. If there is more than one mesh in the batch,
|
||||
the plotly figure will contain a series of vertically stacked subplots.
|
||||
"""
|
||||
meshes = meshes.detach().cpu()
|
||||
subplot_titles = kwargs.get("subplot_titles", None)
|
||||
fig = _gen_fig_with_subplots(len(meshes), in_subplots, ncols, subplot_titles)
|
||||
for i in range(len(meshes)):
|
||||
verts = meshes[i].verts_packed()
|
||||
faces = meshes[i].faces_packed()
|
||||
# If mesh has vertex colors defined as texture, use vertex colors
|
||||
# for figure, otherwise use plotly's default colors.
|
||||
verts_rgb = None
|
||||
if isinstance(meshes[i].textures, TexturesVertex):
|
||||
verts_rgb = meshes[i].textures.verts_features_packed()
|
||||
verts_rgb.clamp_(min=0.0, max=1.0)
|
||||
verts_rgb = torch.tensor(255.0) * verts_rgb
|
||||
|
||||
# Reposition the unused vertices to be "inside" the object
|
||||
# (i.e. they won't be visible in the plot).
|
||||
verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool)
|
||||
verts_used[torch.unique(faces)] = True
|
||||
verts_center = verts[verts_used].mean(0)
|
||||
verts[~verts_used] = verts_center
|
||||
|
||||
trace_row = i // ncols + 1 if in_subplots else 1
|
||||
trace_col = i % ncols + 1 if in_subplots else 1
|
||||
fig.add_trace(
|
||||
go.Mesh3d( # pyre-ignore[16]
|
||||
x=verts[:, 0],
|
||||
y=verts[:, 1],
|
||||
z=verts[:, 2],
|
||||
vertexcolor=verts_rgb,
|
||||
i=faces[:, 0],
|
||||
j=faces[:, 1],
|
||||
k=faces[:, 2],
|
||||
lighting=kwargs.get("lighting", Lighting())._asdict(),
|
||||
lightposition=kwargs.get("lightposition", {}),
|
||||
),
|
||||
row=trace_row,
|
||||
col=trace_col,
|
||||
)
|
||||
# Ensure update for every subplot.
|
||||
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
|
||||
current_layout = fig["layout"][plot_scene]
|
||||
|
||||
axis_args = kwargs.get("axis_args", AxisArgs())
|
||||
|
||||
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
|
||||
verts, verts_center, current_layout, axis_args
|
||||
)
|
||||
# Update the axis bounds with the axis settings passed in as kwargs.
|
||||
xaxis.update(**kwargs.get("xaxis", {}))
|
||||
yaxis.update(**kwargs.get("yaxis", {}))
|
||||
zaxis.update(**kwargs.get("zaxis", {}))
|
||||
|
||||
current_layout.update(
|
||||
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
|
||||
)
|
||||
return fig
|
||||
|
||||
|
||||
def plot_pointclouds(
|
||||
pointclouds: Pointclouds,
|
||||
def plot_scene(
|
||||
plots: Dict[str, Dict[str, Union[Pointclouds, Meshes]]],
|
||||
*,
|
||||
in_subplots: bool = False,
|
||||
ncols: int = 1,
|
||||
max_points: int = 20000,
|
||||
pointcloud_max_points: int = 20000,
|
||||
pointcloud_marker_size: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Takes a Pointclouds object and generates a plotly figure. If there is more than
|
||||
one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be
|
||||
visualized in an individual subplot with ncols number of subplots in the same row.
|
||||
Otherwise, each pointcloud in the batch will be visualized as an individual trace in the
|
||||
same plot. If the Pointclouds object has features that are size (3) or (4) then those
|
||||
rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors
|
||||
will be used. Assumes that all rgb/rgba feature values are in the range [0,1].
|
||||
Main function to visualize Meshes and Pointclouds.
|
||||
Plots input Pointclouds and Meshes data into named subplots,
|
||||
with named traces based on the dictionary keys.
|
||||
|
||||
Args:
|
||||
pointclouds: Pointclouds object which can contain a batch of pointclouds.
|
||||
in_subplots: if each pointcloud should be visualized in an individual subplot.
|
||||
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
||||
ncols will be ignored.
|
||||
max_points: maximum number of points to plot. If the cloud has more, they are
|
||||
randomly subsampled.
|
||||
**kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D
|
||||
and any of the args xaxis, yaxis and zaxis which scene accepts.
|
||||
Accepts axis_args which is an AxisArgs object that is applied to the 3 axes.
|
||||
Also accepts subplot_titles, whichshould be a list of string titles
|
||||
matching the number of subplots. Example settings for axis_args and lighting are
|
||||
given at the top of this file.
|
||||
plots: A dict containing subplot and trace names,
|
||||
as well as the Meshes and Pointclouds objects to be rendered.
|
||||
See below for examples of the format.
|
||||
ncols: the number of subplots per row
|
||||
pointcloud_max_points: the maximum number of points to plot from
|
||||
a pointcloud. If more are present, a random sample of size
|
||||
pointcloud_max_points is used.
|
||||
pointcloud_marker_size: the size of the points rendered by plotly
|
||||
when plotting a pointcloud.
|
||||
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
|
||||
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
|
||||
which is an AxisArgs object that is applied to all 3 axes.
|
||||
Example settings for axis_args and lighting are given at the
|
||||
top of this file.
|
||||
|
||||
Returns:
|
||||
Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch,
|
||||
the plotly figure will contain a plot with one trace per pointcloud, or with each
|
||||
pointcloud in a separate subplot if in_subplots is True.
|
||||
Example:
|
||||
|
||||
..code-block::python
|
||||
|
||||
mesh = ...
|
||||
point_cloud = ...
|
||||
fig = plot_scene({
|
||||
"subplot_title": {
|
||||
"mesh_trace_title": mesh,
|
||||
"pointcloud_trace_title": point_cloud
|
||||
}
|
||||
})
|
||||
fig.show()
|
||||
|
||||
The above example will render one subplot which has both a mesh and pointcloud.
|
||||
|
||||
If the Meshes or Pointclouds objects are batched, then every object in that batch
|
||||
will be plotted in a single trace.
|
||||
|
||||
..code-block::python
|
||||
mesh = ... # batch size 2
|
||||
point_cloud = ... # batch size 2
|
||||
fig = plot_scene({
|
||||
"subplot_title": {
|
||||
"mesh_trace_title": mesh,
|
||||
"pointcloud_trace_title": point_cloud
|
||||
}
|
||||
})
|
||||
fig.show()
|
||||
|
||||
The above example renders one subplot with 2 traces, each of which renders
|
||||
both objects from their respective batched data.
|
||||
|
||||
Multiple subplots follow the same pattern:
|
||||
..code-block::python
|
||||
mesh = ... # batch size 2
|
||||
point_cloud = ... # batch size 2
|
||||
fig = plot_scene({
|
||||
"subplot1_title": {
|
||||
"mesh_trace_title": mesh[0],
|
||||
"pointcloud_trace_title": point_cloud[0]
|
||||
},
|
||||
"subplot2_title": {
|
||||
"mesh_trace_title": mesh[1],
|
||||
"pointcloud_trace_title": point_cloud[1]
|
||||
}
|
||||
},
|
||||
ncols=2) # specify the number of subplots per row
|
||||
fig.show()
|
||||
|
||||
The above example will render two subplots, each containing a mesh
|
||||
and a pointcloud. The ncols argument will render two subplots in one row
|
||||
instead of having them vertically stacked because the default is one subplot
|
||||
per row.
|
||||
|
||||
For an example of using kwargs, see below:
|
||||
..code-block::python
|
||||
mesh = ...
|
||||
point_cloud = ...
|
||||
fig = plot_scene({
|
||||
"subplot_title": {
|
||||
"mesh_trace_title": mesh,
|
||||
"pointcloud_trace_title": point_cloud
|
||||
}
|
||||
},
|
||||
axis_args=AxisArgs(backgroundcolor="rgb(200,230,200)")) # kwarg axis_args
|
||||
fig.show()
|
||||
|
||||
The above example will render each axis with the input background color.
|
||||
|
||||
See the tutorials in pytorch3d/docs/tutorials for more examples
|
||||
(namely rendered_color_points.ipynb and rendered_textured_meshes.ipynb).
|
||||
"""
|
||||
pointclouds = pointclouds.detach().cpu()
|
||||
subplot_titles = kwargs.get("subplot_titles", None)
|
||||
fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles)
|
||||
for i in range(len(pointclouds)):
|
||||
verts = pointclouds[i].points_packed()
|
||||
features = pointclouds[i].features_packed()
|
||||
subplots = list(plots.keys())
|
||||
fig = _gen_fig_with_subplots(len(subplots), ncols, subplots)
|
||||
lighting = kwargs.get("lighting", Lighting())._asdict()
|
||||
axis_args_dict = kwargs.get("axis_args", AxisArgs())._asdict()
|
||||
|
||||
indices = None
|
||||
if max_points is not None and verts.shape[0] > max_points:
|
||||
indices = np.random.choice(verts.shape[0], max_points, replace=False)
|
||||
verts = verts[indices]
|
||||
# Set axis arguments to defaults defined at the top of this file
|
||||
x_settings = {**axis_args_dict}
|
||||
y_settings = {**axis_args_dict}
|
||||
z_settings = {**axis_args_dict}
|
||||
|
||||
color = None
|
||||
if features is not None:
|
||||
features = features[indices]
|
||||
if features.shape[1] == 4: # rgba
|
||||
template = "rgb(%d, %d, %d, %f)"
|
||||
rgb = (features[:, :3] * 255).int()
|
||||
color = [
|
||||
template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])
|
||||
]
|
||||
# Update the axes with any axis settings passed in as kwargs.
|
||||
x_settings.update(**kwargs.get("xaxis", {}))
|
||||
y_settings.update(**kwargs.get("yaxis", {}))
|
||||
z_settings.update(**kwargs.get("zaxis", {}))
|
||||
|
||||
if features.shape[1] == 3:
|
||||
template = "rgb(%d, %d, %d)"
|
||||
rgb = (features * 255).int()
|
||||
color = [template % (r, g, b) for r, g, b in rgb]
|
||||
camera = {
|
||||
"up": {
|
||||
"x": 0,
|
||||
"y": 1,
|
||||
"z": 0,
|
||||
} # set the up vector to match PyTorch3D world coordinates conventions
|
||||
}
|
||||
|
||||
trace_row = i // ncols + 1 if in_subplots else 1
|
||||
trace_col = i % ncols + 1 if in_subplots else 1
|
||||
fig.add_trace(
|
||||
go.Scatter3d( # pyre-ignore[16]
|
||||
x=verts[:, 0],
|
||||
y=verts[:, 1],
|
||||
z=verts[:, 2],
|
||||
marker={"color": color, "size": 1},
|
||||
mode="markers",
|
||||
),
|
||||
row=trace_row,
|
||||
col=trace_col,
|
||||
)
|
||||
for subplot_idx in range(len(subplots)):
|
||||
subplot_name = subplots[subplot_idx]
|
||||
traces = plots[subplot_name]
|
||||
for trace_name, struct in traces.items():
|
||||
if isinstance(struct, Meshes):
|
||||
_add_mesh_trace(fig, struct, trace_name, subplot_idx, ncols, lighting)
|
||||
elif isinstance(struct, Pointclouds):
|
||||
_add_pointcloud_trace(
|
||||
fig,
|
||||
struct,
|
||||
trace_name,
|
||||
subplot_idx,
|
||||
ncols,
|
||||
pointcloud_max_points,
|
||||
pointcloud_marker_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"struct {} is not a Meshes or Pointclouds object".format(struct)
|
||||
)
|
||||
|
||||
# Ensure update for every subplot.
|
||||
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
|
||||
plot_scene = "scene" + str(subplot_idx + 1)
|
||||
current_layout = fig["layout"][plot_scene]
|
||||
xaxis = current_layout["xaxis"]
|
||||
yaxis = current_layout["yaxis"]
|
||||
zaxis = current_layout["zaxis"]
|
||||
|
||||
verts_center = verts.mean(0)
|
||||
|
||||
axis_args = kwargs.get("axis_args", AxisArgs())
|
||||
|
||||
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
|
||||
verts, verts_center, current_layout, axis_args
|
||||
)
|
||||
xaxis.update(**kwargs.get("xaxis", {}))
|
||||
yaxis.update(**kwargs.get("yaxis", {}))
|
||||
zaxis.update(**kwargs.get("zaxis", {}))
|
||||
# Update the axes with our above default and provided settings.
|
||||
xaxis.update(**x_settings)
|
||||
yaxis.update(**y_settings)
|
||||
zaxis.update(**z_settings)
|
||||
|
||||
current_layout.update(
|
||||
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
|
||||
{
|
||||
"xaxis": xaxis,
|
||||
"yaxis": yaxis,
|
||||
"zaxis": zaxis,
|
||||
"aspectmode": "cube",
|
||||
"camera": camera,
|
||||
}
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _gen_fig_with_subplots(
|
||||
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list]
|
||||
def _add_mesh_trace(
|
||||
fig: go.Figure, # pyre-ignore[11]
|
||||
meshes: Meshes,
|
||||
trace_name: str,
|
||||
subplot_idx: int,
|
||||
ncols: int,
|
||||
lighting: Lighting,
|
||||
):
|
||||
"""
|
||||
Adds a trace rendering a Meshes object to the passed in figure, with
|
||||
a given name and in a specific subplot.
|
||||
|
||||
Args:
|
||||
fig: plotly figure to add the trace within.
|
||||
meshes: Meshes object to render. It can be batched.
|
||||
trace_name: name to label the trace with.
|
||||
subplot_idx: identifies the subplot, with 0 being the top left.
|
||||
ncols: the number of subplots per row.
|
||||
lighting: a Lighting object that specifies the Mesh3D lighting.
|
||||
"""
|
||||
|
||||
mesh = join_meshes_as_scene(meshes)
|
||||
mesh = mesh.detach().cpu()
|
||||
verts = mesh.verts_packed()
|
||||
faces = mesh.faces_packed()
|
||||
# If mesh has vertex colors defined as texture, use vertex colors
|
||||
# for figure, otherwise use plotly's default colors.
|
||||
verts_rgb = None
|
||||
if isinstance(mesh.textures, TexturesVertex):
|
||||
verts_rgb = mesh.textures.verts_features_packed()
|
||||
verts_rgb.clamp_(min=0.0, max=1.0)
|
||||
verts_rgb = torch.tensor(255.0) * verts_rgb
|
||||
|
||||
# Reposition the unused vertices to be "inside" the object
|
||||
# (i.e. they won't be visible in the plot).
|
||||
verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool)
|
||||
verts_used[torch.unique(faces)] = True
|
||||
verts_center = verts[verts_used].mean(0)
|
||||
verts[~verts_used] = verts_center
|
||||
|
||||
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||
fig.add_trace(
|
||||
go.Mesh3d( # pyre-ignore[16]
|
||||
x=verts[:, 0],
|
||||
y=verts[:, 1],
|
||||
z=verts[:, 2],
|
||||
vertexcolor=verts_rgb,
|
||||
i=faces[:, 0],
|
||||
j=faces[:, 1],
|
||||
k=faces[:, 2],
|
||||
lighting=lighting,
|
||||
name=trace_name,
|
||||
),
|
||||
row=row,
|
||||
col=col,
|
||||
)
|
||||
|
||||
# Access the current subplot's scene configuration
|
||||
plot_scene = "scene" + str(subplot_idx + 1)
|
||||
current_layout = fig["layout"][plot_scene]
|
||||
|
||||
# update the bounds of the axes for the current trace
|
||||
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
|
||||
_update_axes_bounds(verts_center, max_expand, current_layout)
|
||||
|
||||
|
||||
def _add_pointcloud_trace(
|
||||
fig: go.Figure,
|
||||
pointclouds: Pointclouds,
|
||||
trace_name: str,
|
||||
subplot_idx: int,
|
||||
ncols: int,
|
||||
max_points_per_pointcloud: int,
|
||||
marker_size: int,
|
||||
):
|
||||
"""
|
||||
Adds a trace rendering a Pointclouds object to the passed in figure, with
|
||||
a given name and in a specific subplot.
|
||||
|
||||
Args:
|
||||
fig: plotly figure to add the trace within.
|
||||
pointclouds: Pointclouds object to render. It can be batched.
|
||||
trace_name: name to label the trace with.
|
||||
subplot_idx: identifies the subplot, with 0 being the top left.
|
||||
ncols: the number of sublpots per row.
|
||||
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
|
||||
marker_size: the size of the rendered points
|
||||
"""
|
||||
pointclouds = pointclouds.detach().cpu()
|
||||
verts = pointclouds.points_packed()
|
||||
features = pointclouds.features_packed()
|
||||
total_points_count = max_points_per_pointcloud * len(pointclouds)
|
||||
|
||||
indices = None
|
||||
if verts.shape[0] > total_points_count:
|
||||
indices = np.random.choice(verts.shape[0], total_points_count, replace=False)
|
||||
verts = verts[indices]
|
||||
|
||||
color = None
|
||||
if features is not None:
|
||||
features = features[indices]
|
||||
if features.shape[1] == 4: # rgba
|
||||
template = "rgb(%d, %d, %d, %f)"
|
||||
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()
|
||||
color = [template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])]
|
||||
|
||||
if features.shape[1] == 3:
|
||||
template = "rgb(%d, %d, %d)"
|
||||
rgb = (features.clamp(0.0, 1.0) * 255).int()
|
||||
color = [template % (r, g, b) for r, g, b in rgb]
|
||||
|
||||
row = subplot_idx // ncols + 1
|
||||
col = subplot_idx % ncols + 1
|
||||
fig.add_trace(
|
||||
go.Scatter3d( # pyre-ignore[16]
|
||||
x=verts[:, 0],
|
||||
y=verts[:, 1],
|
||||
z=verts[:, 2],
|
||||
marker={"color": color, "size": marker_size},
|
||||
mode="markers",
|
||||
name=trace_name,
|
||||
),
|
||||
row=row,
|
||||
col=col,
|
||||
)
|
||||
|
||||
# Access the current subplot's scene configuration
|
||||
plot_scene = "scene" + str(subplot_idx + 1)
|
||||
current_layout = fig["layout"][plot_scene]
|
||||
|
||||
# update the bounds of the axes for the current trace
|
||||
verts_center = verts.mean(0)
|
||||
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
|
||||
_update_axes_bounds(verts_center, max_expand, current_layout)
|
||||
|
||||
|
||||
def _gen_fig_with_subplots(batch_size: int, ncols: int, subplot_titles: List[str]):
|
||||
"""
|
||||
Takes in the number of objects to be plotted and generate a plotly figure
|
||||
with the appropriate number and orientation of subplots
|
||||
with the appropriate number and orientation of titled subplots.
|
||||
Args:
|
||||
batch_size: the number of elements in the batch of objects to be visualized.
|
||||
in_subplots: if each object should be visualized in an individual subplot.
|
||||
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
||||
ncols will be ignored.
|
||||
subplot_titles: titles for the subplot(s). list of strings of length batch_size
|
||||
if in_subplots is True, otherwise length 1.
|
||||
ncols: number of subplots in the same row.
|
||||
subplot_titles: titles for the subplot(s). list of strings of length batch_size.
|
||||
|
||||
Returns:
|
||||
Plotly figure with one plot if in_subplots is false. Otherwise, returns a plotly figure
|
||||
with ncols subplots per row.
|
||||
Plotly figure with ncols subplots per row, and batch_size subplots.
|
||||
"""
|
||||
if batch_size % ncols != 0:
|
||||
msg = "ncols is invalid for the given mesh batch size."
|
||||
warnings.warn(msg)
|
||||
fig_rows = batch_size // ncols if in_subplots else 1
|
||||
fig_cols = ncols if in_subplots else 1
|
||||
fig_rows = batch_size // ncols
|
||||
fig_cols = ncols
|
||||
fig_type = [{"type": "scene"}]
|
||||
specs = [fig_type * fig_cols] * fig_rows
|
||||
# subplot_titles must have one title per subplot
|
||||
if subplot_titles is not None and len(subplot_titles) != fig_cols * fig_rows:
|
||||
subplot_titles = None
|
||||
fig = make_subplots(
|
||||
rows=fig_rows,
|
||||
cols=fig_cols,
|
||||
@ -251,27 +368,19 @@ def _gen_fig_with_subplots(
|
||||
return fig
|
||||
|
||||
|
||||
def _gen_updated_axis_bounds(
|
||||
verts: torch.Tensor,
|
||||
def _update_axes_bounds(
|
||||
verts_center: torch.Tensor,
|
||||
max_expand: float,
|
||||
current_layout: go.Scene, # pyre-ignore[11]
|
||||
axis_args: AxisArgs,
|
||||
) -> Tuple[dict, dict, dict]:
|
||||
):
|
||||
"""
|
||||
Takes in the vertices, center point of the vertices, and the current plotly figure and
|
||||
outputs axes with bounds that capture all points in the current subplot.
|
||||
Takes in the vertices' center point and max spread, and the current plotly figure
|
||||
layout and updates the layout to have bounds that include all traces for that subplot.
|
||||
Args:
|
||||
verts: tensor of size (N, 3) representing N points with xyz coordinates.
|
||||
verts_center: tensor of size (3) corresponding to the center point of verts.
|
||||
current_layout: the current plotly figure layout scene corresponding to verts' trace.
|
||||
axis_args: an AxisArgs object with default and/or user-set values for plotly's axes.
|
||||
|
||||
Returns:
|
||||
a 3 item tuple of xaxis, yaxis, and zaxis, which are dictionaries with axis arguments
|
||||
for plotly including a range key with value the minimum and maximum value for that axis.
|
||||
verts_center: tensor of size (3) corresponding to a trace's vertices' center point.
|
||||
max_expand: the maximum spread in any dimension of the trace's vertices.
|
||||
current_layout: the plotly figure layout scene corresponding to the referenced trace.
|
||||
"""
|
||||
# Get ranges of vertices.
|
||||
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
|
||||
verts_min = verts_center - max_expand
|
||||
verts_max = verts_center + max_expand
|
||||
bounds = torch.t(torch.stack((verts_min, verts_max)))
|
||||
@ -292,8 +401,8 @@ def _gen_updated_axis_bounds(
|
||||
if old_zrange is not None:
|
||||
z_range[0] = min(z_range[0], old_zrange[0])
|
||||
z_range[1] = max(z_range[1], old_zrange[1])
|
||||
axis_args_dict = axis_args._asdict()
|
||||
xaxis = {"range": x_range, **axis_args_dict}
|
||||
yaxis = {"range": y_range, **axis_args_dict}
|
||||
zaxis = {"range": z_range, **axis_args_dict}
|
||||
return xaxis, yaxis, zaxis
|
||||
|
||||
xaxis = {"range": x_range}
|
||||
yaxis = {"range": y_range}
|
||||
zaxis = {"range": z_range}
|
||||
current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})
|
||||
|
Loading…
x
Reference in New Issue
Block a user