mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Refactor plot_meshes and plot_pointclouds to one generalizable API, plot_scene
Summary: Defines a function plot_scene that takes in a dictionary defining subplot and trace layouts for Mesh/Pointcloud objects and plots them. Also supports other plotly axis arguments and mesh lighting. Plot_batch_individually is a wrapper function that takes in one or multiple batched Meshes/Pointclouds and uses plot_scene to plot each element within a batch in an individual subplot, possibly sharing that subplot with traces of other individual elements of the other batched structures passed in. Reviewed By: nikhilaravi Differential Revision: D24235479 fbshipit-source-id: 9f669f1b186d55fe5c75552083316c0cf1387472
This commit is contained in:
parent
abd390319c
commit
964893cdcb
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds
|
from .plotly_vis import AxisArgs, Lighting, plot_scene
|
||||||
|
|
||||||
|
|
||||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import NamedTuple, Optional, Tuple
|
from typing import Dict, List, NamedTuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
import torch
|
import torch
|
||||||
from plotly.subplots import make_subplots
|
from plotly.subplots import make_subplots
|
||||||
from pytorch3d.renderer import TexturesVertex
|
from pytorch3d.renderer import TexturesVertex
|
||||||
from pytorch3d.structures import Meshes, Pointclouds
|
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
|
||||||
|
|
||||||
|
|
||||||
class AxisArgs(NamedTuple):
|
class AxisArgs(NamedTuple):
|
||||||
@ -31,216 +31,333 @@ class Lighting(NamedTuple):
|
|||||||
vertexnormalsepsilon: float = 1e-12
|
vertexnormalsepsilon: float = 1e-12
|
||||||
|
|
||||||
|
|
||||||
def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **kwargs):
|
def plot_scene(
|
||||||
"""
|
plots: Dict[str, Dict[str, Union[Pointclouds, Meshes]]],
|
||||||
Takes a Meshes object and generates a plotly figure. If there is more than
|
|
||||||
one mesh in the batch and in_subplots=True, each mesh will be
|
|
||||||
visualized in an individual subplot with ncols number of subplots in the same row.
|
|
||||||
Otherwise, each mesh in the batch will be visualized as an individual trace in the
|
|
||||||
same plot. If the Meshes object has vertex colors defined as its texture, the vertex
|
|
||||||
colors will be used for generating the plotly figure. Otherwise plotly's default
|
|
||||||
colors will be used.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
meshes: Meshes object to be visualized in a plotly figure.
|
|
||||||
in_subplots: if each mesh in the batch should be visualized in an individual subplot
|
|
||||||
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
|
||||||
ncols will be ignored.
|
|
||||||
**kwargs: Accepts lighting (a Lighting object) and lightposition for Mesh3D and any of
|
|
||||||
the args xaxis, yaxis and zaxis accept for scene. Accepts axis_args which is an
|
|
||||||
AxisArgs object that is applied to the 3 axes. Also accepts subplot_titles, which
|
|
||||||
should be a list of string titles matching the number of subplots.
|
|
||||||
Example settings for axis_args and lighting are given above.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plotly figure of the mesh. If there is more than one mesh in the batch,
|
|
||||||
the plotly figure will contain a series of vertically stacked subplots.
|
|
||||||
"""
|
|
||||||
meshes = meshes.detach().cpu()
|
|
||||||
subplot_titles = kwargs.get("subplot_titles", None)
|
|
||||||
fig = _gen_fig_with_subplots(len(meshes), in_subplots, ncols, subplot_titles)
|
|
||||||
for i in range(len(meshes)):
|
|
||||||
verts = meshes[i].verts_packed()
|
|
||||||
faces = meshes[i].faces_packed()
|
|
||||||
# If mesh has vertex colors defined as texture, use vertex colors
|
|
||||||
# for figure, otherwise use plotly's default colors.
|
|
||||||
verts_rgb = None
|
|
||||||
if isinstance(meshes[i].textures, TexturesVertex):
|
|
||||||
verts_rgb = meshes[i].textures.verts_features_packed()
|
|
||||||
verts_rgb.clamp_(min=0.0, max=1.0)
|
|
||||||
verts_rgb = torch.tensor(255.0) * verts_rgb
|
|
||||||
|
|
||||||
# Reposition the unused vertices to be "inside" the object
|
|
||||||
# (i.e. they won't be visible in the plot).
|
|
||||||
verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool)
|
|
||||||
verts_used[torch.unique(faces)] = True
|
|
||||||
verts_center = verts[verts_used].mean(0)
|
|
||||||
verts[~verts_used] = verts_center
|
|
||||||
|
|
||||||
trace_row = i // ncols + 1 if in_subplots else 1
|
|
||||||
trace_col = i % ncols + 1 if in_subplots else 1
|
|
||||||
fig.add_trace(
|
|
||||||
go.Mesh3d( # pyre-ignore[16]
|
|
||||||
x=verts[:, 0],
|
|
||||||
y=verts[:, 1],
|
|
||||||
z=verts[:, 2],
|
|
||||||
vertexcolor=verts_rgb,
|
|
||||||
i=faces[:, 0],
|
|
||||||
j=faces[:, 1],
|
|
||||||
k=faces[:, 2],
|
|
||||||
lighting=kwargs.get("lighting", Lighting())._asdict(),
|
|
||||||
lightposition=kwargs.get("lightposition", {}),
|
|
||||||
),
|
|
||||||
row=trace_row,
|
|
||||||
col=trace_col,
|
|
||||||
)
|
|
||||||
# Ensure update for every subplot.
|
|
||||||
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
|
|
||||||
current_layout = fig["layout"][plot_scene]
|
|
||||||
|
|
||||||
axis_args = kwargs.get("axis_args", AxisArgs())
|
|
||||||
|
|
||||||
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
|
|
||||||
verts, verts_center, current_layout, axis_args
|
|
||||||
)
|
|
||||||
# Update the axis bounds with the axis settings passed in as kwargs.
|
|
||||||
xaxis.update(**kwargs.get("xaxis", {}))
|
|
||||||
yaxis.update(**kwargs.get("yaxis", {}))
|
|
||||||
zaxis.update(**kwargs.get("zaxis", {}))
|
|
||||||
|
|
||||||
current_layout.update(
|
|
||||||
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
|
|
||||||
)
|
|
||||||
return fig
|
|
||||||
|
|
||||||
|
|
||||||
def plot_pointclouds(
|
|
||||||
pointclouds: Pointclouds,
|
|
||||||
*,
|
*,
|
||||||
in_subplots: bool = False,
|
|
||||||
ncols: int = 1,
|
ncols: int = 1,
|
||||||
max_points: int = 20000,
|
pointcloud_max_points: int = 20000,
|
||||||
|
pointcloud_marker_size: int = 1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Takes a Pointclouds object and generates a plotly figure. If there is more than
|
Main function to visualize Meshes and Pointclouds.
|
||||||
one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be
|
Plots input Pointclouds and Meshes data into named subplots,
|
||||||
visualized in an individual subplot with ncols number of subplots in the same row.
|
with named traces based on the dictionary keys.
|
||||||
Otherwise, each pointcloud in the batch will be visualized as an individual trace in the
|
|
||||||
same plot. If the Pointclouds object has features that are size (3) or (4) then those
|
|
||||||
rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors
|
|
||||||
will be used. Assumes that all rgb/rgba feature values are in the range [0,1].
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pointclouds: Pointclouds object which can contain a batch of pointclouds.
|
plots: A dict containing subplot and trace names,
|
||||||
in_subplots: if each pointcloud should be visualized in an individual subplot.
|
as well as the Meshes and Pointclouds objects to be rendered.
|
||||||
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
See below for examples of the format.
|
||||||
ncols will be ignored.
|
ncols: the number of subplots per row
|
||||||
max_points: maximum number of points to plot. If the cloud has more, they are
|
pointcloud_max_points: the maximum number of points to plot from
|
||||||
randomly subsampled.
|
a pointcloud. If more are present, a random sample of size
|
||||||
**kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D
|
pointcloud_max_points is used.
|
||||||
and any of the args xaxis, yaxis and zaxis which scene accepts.
|
pointcloud_marker_size: the size of the points rendered by plotly
|
||||||
Accepts axis_args which is an AxisArgs object that is applied to the 3 axes.
|
when plotting a pointcloud.
|
||||||
Also accepts subplot_titles, whichshould be a list of string titles
|
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
|
||||||
matching the number of subplots. Example settings for axis_args and lighting are
|
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
|
||||||
given at the top of this file.
|
which is an AxisArgs object that is applied to all 3 axes.
|
||||||
|
Example settings for axis_args and lighting are given at the
|
||||||
|
top of this file.
|
||||||
|
|
||||||
Returns:
|
Example:
|
||||||
Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch,
|
|
||||||
the plotly figure will contain a plot with one trace per pointcloud, or with each
|
..code-block::python
|
||||||
pointcloud in a separate subplot if in_subplots is True.
|
|
||||||
|
mesh = ...
|
||||||
|
point_cloud = ...
|
||||||
|
fig = plot_scene({
|
||||||
|
"subplot_title": {
|
||||||
|
"mesh_trace_title": mesh,
|
||||||
|
"pointcloud_trace_title": point_cloud
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
The above example will render one subplot which has both a mesh and pointcloud.
|
||||||
|
|
||||||
|
If the Meshes or Pointclouds objects are batched, then every object in that batch
|
||||||
|
will be plotted in a single trace.
|
||||||
|
|
||||||
|
..code-block::python
|
||||||
|
mesh = ... # batch size 2
|
||||||
|
point_cloud = ... # batch size 2
|
||||||
|
fig = plot_scene({
|
||||||
|
"subplot_title": {
|
||||||
|
"mesh_trace_title": mesh,
|
||||||
|
"pointcloud_trace_title": point_cloud
|
||||||
|
}
|
||||||
|
})
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
The above example renders one subplot with 2 traces, each of which renders
|
||||||
|
both objects from their respective batched data.
|
||||||
|
|
||||||
|
Multiple subplots follow the same pattern:
|
||||||
|
..code-block::python
|
||||||
|
mesh = ... # batch size 2
|
||||||
|
point_cloud = ... # batch size 2
|
||||||
|
fig = plot_scene({
|
||||||
|
"subplot1_title": {
|
||||||
|
"mesh_trace_title": mesh[0],
|
||||||
|
"pointcloud_trace_title": point_cloud[0]
|
||||||
|
},
|
||||||
|
"subplot2_title": {
|
||||||
|
"mesh_trace_title": mesh[1],
|
||||||
|
"pointcloud_trace_title": point_cloud[1]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ncols=2) # specify the number of subplots per row
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
The above example will render two subplots, each containing a mesh
|
||||||
|
and a pointcloud. The ncols argument will render two subplots in one row
|
||||||
|
instead of having them vertically stacked because the default is one subplot
|
||||||
|
per row.
|
||||||
|
|
||||||
|
For an example of using kwargs, see below:
|
||||||
|
..code-block::python
|
||||||
|
mesh = ...
|
||||||
|
point_cloud = ...
|
||||||
|
fig = plot_scene({
|
||||||
|
"subplot_title": {
|
||||||
|
"mesh_trace_title": mesh,
|
||||||
|
"pointcloud_trace_title": point_cloud
|
||||||
|
}
|
||||||
|
},
|
||||||
|
axis_args=AxisArgs(backgroundcolor="rgb(200,230,200)")) # kwarg axis_args
|
||||||
|
fig.show()
|
||||||
|
|
||||||
|
The above example will render each axis with the input background color.
|
||||||
|
|
||||||
|
See the tutorials in pytorch3d/docs/tutorials for more examples
|
||||||
|
(namely rendered_color_points.ipynb and rendered_textured_meshes.ipynb).
|
||||||
"""
|
"""
|
||||||
pointclouds = pointclouds.detach().cpu()
|
subplots = list(plots.keys())
|
||||||
subplot_titles = kwargs.get("subplot_titles", None)
|
fig = _gen_fig_with_subplots(len(subplots), ncols, subplots)
|
||||||
fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles)
|
lighting = kwargs.get("lighting", Lighting())._asdict()
|
||||||
for i in range(len(pointclouds)):
|
axis_args_dict = kwargs.get("axis_args", AxisArgs())._asdict()
|
||||||
verts = pointclouds[i].points_packed()
|
|
||||||
features = pointclouds[i].features_packed()
|
|
||||||
|
|
||||||
indices = None
|
# Set axis arguments to defaults defined at the top of this file
|
||||||
if max_points is not None and verts.shape[0] > max_points:
|
x_settings = {**axis_args_dict}
|
||||||
indices = np.random.choice(verts.shape[0], max_points, replace=False)
|
y_settings = {**axis_args_dict}
|
||||||
verts = verts[indices]
|
z_settings = {**axis_args_dict}
|
||||||
|
|
||||||
color = None
|
# Update the axes with any axis settings passed in as kwargs.
|
||||||
if features is not None:
|
x_settings.update(**kwargs.get("xaxis", {}))
|
||||||
features = features[indices]
|
y_settings.update(**kwargs.get("yaxis", {}))
|
||||||
if features.shape[1] == 4: # rgba
|
z_settings.update(**kwargs.get("zaxis", {}))
|
||||||
template = "rgb(%d, %d, %d, %f)"
|
|
||||||
rgb = (features[:, :3] * 255).int()
|
|
||||||
color = [
|
|
||||||
template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])
|
|
||||||
]
|
|
||||||
|
|
||||||
if features.shape[1] == 3:
|
camera = {
|
||||||
template = "rgb(%d, %d, %d)"
|
"up": {
|
||||||
rgb = (features * 255).int()
|
"x": 0,
|
||||||
color = [template % (r, g, b) for r, g, b in rgb]
|
"y": 1,
|
||||||
|
"z": 0,
|
||||||
|
} # set the up vector to match PyTorch3D world coordinates conventions
|
||||||
|
}
|
||||||
|
|
||||||
trace_row = i // ncols + 1 if in_subplots else 1
|
for subplot_idx in range(len(subplots)):
|
||||||
trace_col = i % ncols + 1 if in_subplots else 1
|
subplot_name = subplots[subplot_idx]
|
||||||
fig.add_trace(
|
traces = plots[subplot_name]
|
||||||
go.Scatter3d( # pyre-ignore[16]
|
for trace_name, struct in traces.items():
|
||||||
x=verts[:, 0],
|
if isinstance(struct, Meshes):
|
||||||
y=verts[:, 1],
|
_add_mesh_trace(fig, struct, trace_name, subplot_idx, ncols, lighting)
|
||||||
z=verts[:, 2],
|
elif isinstance(struct, Pointclouds):
|
||||||
marker={"color": color, "size": 1},
|
_add_pointcloud_trace(
|
||||||
mode="markers",
|
fig,
|
||||||
),
|
struct,
|
||||||
row=trace_row,
|
trace_name,
|
||||||
col=trace_col,
|
subplot_idx,
|
||||||
)
|
ncols,
|
||||||
|
pointcloud_max_points,
|
||||||
|
pointcloud_marker_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"struct {} is not a Meshes or Pointclouds object".format(struct)
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure update for every subplot.
|
# Ensure update for every subplot.
|
||||||
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
|
plot_scene = "scene" + str(subplot_idx + 1)
|
||||||
current_layout = fig["layout"][plot_scene]
|
current_layout = fig["layout"][plot_scene]
|
||||||
|
xaxis = current_layout["xaxis"]
|
||||||
|
yaxis = current_layout["yaxis"]
|
||||||
|
zaxis = current_layout["zaxis"]
|
||||||
|
|
||||||
verts_center = verts.mean(0)
|
# Update the axes with our above default and provided settings.
|
||||||
|
xaxis.update(**x_settings)
|
||||||
axis_args = kwargs.get("axis_args", AxisArgs())
|
yaxis.update(**y_settings)
|
||||||
|
zaxis.update(**z_settings)
|
||||||
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
|
|
||||||
verts, verts_center, current_layout, axis_args
|
|
||||||
)
|
|
||||||
xaxis.update(**kwargs.get("xaxis", {}))
|
|
||||||
yaxis.update(**kwargs.get("yaxis", {}))
|
|
||||||
zaxis.update(**kwargs.get("zaxis", {}))
|
|
||||||
|
|
||||||
current_layout.update(
|
current_layout.update(
|
||||||
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
|
{
|
||||||
|
"xaxis": xaxis,
|
||||||
|
"yaxis": yaxis,
|
||||||
|
"zaxis": zaxis,
|
||||||
|
"aspectmode": "cube",
|
||||||
|
"camera": camera,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def _gen_fig_with_subplots(
|
def _add_mesh_trace(
|
||||||
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list]
|
fig: go.Figure, # pyre-ignore[11]
|
||||||
|
meshes: Meshes,
|
||||||
|
trace_name: str,
|
||||||
|
subplot_idx: int,
|
||||||
|
ncols: int,
|
||||||
|
lighting: Lighting,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Adds a trace rendering a Meshes object to the passed in figure, with
|
||||||
|
a given name and in a specific subplot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fig: plotly figure to add the trace within.
|
||||||
|
meshes: Meshes object to render. It can be batched.
|
||||||
|
trace_name: name to label the trace with.
|
||||||
|
subplot_idx: identifies the subplot, with 0 being the top left.
|
||||||
|
ncols: the number of subplots per row.
|
||||||
|
lighting: a Lighting object that specifies the Mesh3D lighting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
mesh = join_meshes_as_scene(meshes)
|
||||||
|
mesh = mesh.detach().cpu()
|
||||||
|
verts = mesh.verts_packed()
|
||||||
|
faces = mesh.faces_packed()
|
||||||
|
# If mesh has vertex colors defined as texture, use vertex colors
|
||||||
|
# for figure, otherwise use plotly's default colors.
|
||||||
|
verts_rgb = None
|
||||||
|
if isinstance(mesh.textures, TexturesVertex):
|
||||||
|
verts_rgb = mesh.textures.verts_features_packed()
|
||||||
|
verts_rgb.clamp_(min=0.0, max=1.0)
|
||||||
|
verts_rgb = torch.tensor(255.0) * verts_rgb
|
||||||
|
|
||||||
|
# Reposition the unused vertices to be "inside" the object
|
||||||
|
# (i.e. they won't be visible in the plot).
|
||||||
|
verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool)
|
||||||
|
verts_used[torch.unique(faces)] = True
|
||||||
|
verts_center = verts[verts_used].mean(0)
|
||||||
|
verts[~verts_used] = verts_center
|
||||||
|
|
||||||
|
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||||
|
fig.add_trace(
|
||||||
|
go.Mesh3d( # pyre-ignore[16]
|
||||||
|
x=verts[:, 0],
|
||||||
|
y=verts[:, 1],
|
||||||
|
z=verts[:, 2],
|
||||||
|
vertexcolor=verts_rgb,
|
||||||
|
i=faces[:, 0],
|
||||||
|
j=faces[:, 1],
|
||||||
|
k=faces[:, 2],
|
||||||
|
lighting=lighting,
|
||||||
|
name=trace_name,
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Access the current subplot's scene configuration
|
||||||
|
plot_scene = "scene" + str(subplot_idx + 1)
|
||||||
|
current_layout = fig["layout"][plot_scene]
|
||||||
|
|
||||||
|
# update the bounds of the axes for the current trace
|
||||||
|
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
|
||||||
|
_update_axes_bounds(verts_center, max_expand, current_layout)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_pointcloud_trace(
|
||||||
|
fig: go.Figure,
|
||||||
|
pointclouds: Pointclouds,
|
||||||
|
trace_name: str,
|
||||||
|
subplot_idx: int,
|
||||||
|
ncols: int,
|
||||||
|
max_points_per_pointcloud: int,
|
||||||
|
marker_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adds a trace rendering a Pointclouds object to the passed in figure, with
|
||||||
|
a given name and in a specific subplot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fig: plotly figure to add the trace within.
|
||||||
|
pointclouds: Pointclouds object to render. It can be batched.
|
||||||
|
trace_name: name to label the trace with.
|
||||||
|
subplot_idx: identifies the subplot, with 0 being the top left.
|
||||||
|
ncols: the number of sublpots per row.
|
||||||
|
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
|
||||||
|
marker_size: the size of the rendered points
|
||||||
|
"""
|
||||||
|
pointclouds = pointclouds.detach().cpu()
|
||||||
|
verts = pointclouds.points_packed()
|
||||||
|
features = pointclouds.features_packed()
|
||||||
|
total_points_count = max_points_per_pointcloud * len(pointclouds)
|
||||||
|
|
||||||
|
indices = None
|
||||||
|
if verts.shape[0] > total_points_count:
|
||||||
|
indices = np.random.choice(verts.shape[0], total_points_count, replace=False)
|
||||||
|
verts = verts[indices]
|
||||||
|
|
||||||
|
color = None
|
||||||
|
if features is not None:
|
||||||
|
features = features[indices]
|
||||||
|
if features.shape[1] == 4: # rgba
|
||||||
|
template = "rgb(%d, %d, %d, %f)"
|
||||||
|
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()
|
||||||
|
color = [template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])]
|
||||||
|
|
||||||
|
if features.shape[1] == 3:
|
||||||
|
template = "rgb(%d, %d, %d)"
|
||||||
|
rgb = (features.clamp(0.0, 1.0) * 255).int()
|
||||||
|
color = [template % (r, g, b) for r, g, b in rgb]
|
||||||
|
|
||||||
|
row = subplot_idx // ncols + 1
|
||||||
|
col = subplot_idx % ncols + 1
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter3d( # pyre-ignore[16]
|
||||||
|
x=verts[:, 0],
|
||||||
|
y=verts[:, 1],
|
||||||
|
z=verts[:, 2],
|
||||||
|
marker={"color": color, "size": marker_size},
|
||||||
|
mode="markers",
|
||||||
|
name=trace_name,
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Access the current subplot's scene configuration
|
||||||
|
plot_scene = "scene" + str(subplot_idx + 1)
|
||||||
|
current_layout = fig["layout"][plot_scene]
|
||||||
|
|
||||||
|
# update the bounds of the axes for the current trace
|
||||||
|
verts_center = verts.mean(0)
|
||||||
|
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
|
||||||
|
_update_axes_bounds(verts_center, max_expand, current_layout)
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_fig_with_subplots(batch_size: int, ncols: int, subplot_titles: List[str]):
|
||||||
"""
|
"""
|
||||||
Takes in the number of objects to be plotted and generate a plotly figure
|
Takes in the number of objects to be plotted and generate a plotly figure
|
||||||
with the appropriate number and orientation of subplots
|
with the appropriate number and orientation of titled subplots.
|
||||||
Args:
|
Args:
|
||||||
batch_size: the number of elements in the batch of objects to be visualized.
|
batch_size: the number of elements in the batch of objects to be visualized.
|
||||||
in_subplots: if each object should be visualized in an individual subplot.
|
ncols: number of subplots in the same row.
|
||||||
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
subplot_titles: titles for the subplot(s). list of strings of length batch_size.
|
||||||
ncols will be ignored.
|
|
||||||
subplot_titles: titles for the subplot(s). list of strings of length batch_size
|
|
||||||
if in_subplots is True, otherwise length 1.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Plotly figure with one plot if in_subplots is false. Otherwise, returns a plotly figure
|
Plotly figure with ncols subplots per row, and batch_size subplots.
|
||||||
with ncols subplots per row.
|
|
||||||
"""
|
"""
|
||||||
if batch_size % ncols != 0:
|
if batch_size % ncols != 0:
|
||||||
msg = "ncols is invalid for the given mesh batch size."
|
msg = "ncols is invalid for the given mesh batch size."
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
fig_rows = batch_size // ncols if in_subplots else 1
|
fig_rows = batch_size // ncols
|
||||||
fig_cols = ncols if in_subplots else 1
|
fig_cols = ncols
|
||||||
fig_type = [{"type": "scene"}]
|
fig_type = [{"type": "scene"}]
|
||||||
specs = [fig_type * fig_cols] * fig_rows
|
specs = [fig_type * fig_cols] * fig_rows
|
||||||
# subplot_titles must have one title per subplot
|
# subplot_titles must have one title per subplot
|
||||||
if subplot_titles is not None and len(subplot_titles) != fig_cols * fig_rows:
|
|
||||||
subplot_titles = None
|
|
||||||
fig = make_subplots(
|
fig = make_subplots(
|
||||||
rows=fig_rows,
|
rows=fig_rows,
|
||||||
cols=fig_cols,
|
cols=fig_cols,
|
||||||
@ -251,27 +368,19 @@ def _gen_fig_with_subplots(
|
|||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def _gen_updated_axis_bounds(
|
def _update_axes_bounds(
|
||||||
verts: torch.Tensor,
|
|
||||||
verts_center: torch.Tensor,
|
verts_center: torch.Tensor,
|
||||||
|
max_expand: float,
|
||||||
current_layout: go.Scene, # pyre-ignore[11]
|
current_layout: go.Scene, # pyre-ignore[11]
|
||||||
axis_args: AxisArgs,
|
):
|
||||||
) -> Tuple[dict, dict, dict]:
|
|
||||||
"""
|
"""
|
||||||
Takes in the vertices, center point of the vertices, and the current plotly figure and
|
Takes in the vertices' center point and max spread, and the current plotly figure
|
||||||
outputs axes with bounds that capture all points in the current subplot.
|
layout and updates the layout to have bounds that include all traces for that subplot.
|
||||||
Args:
|
Args:
|
||||||
verts: tensor of size (N, 3) representing N points with xyz coordinates.
|
verts_center: tensor of size (3) corresponding to a trace's vertices' center point.
|
||||||
verts_center: tensor of size (3) corresponding to the center point of verts.
|
max_expand: the maximum spread in any dimension of the trace's vertices.
|
||||||
current_layout: the current plotly figure layout scene corresponding to verts' trace.
|
current_layout: the plotly figure layout scene corresponding to the referenced trace.
|
||||||
axis_args: an AxisArgs object with default and/or user-set values for plotly's axes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
a 3 item tuple of xaxis, yaxis, and zaxis, which are dictionaries with axis arguments
|
|
||||||
for plotly including a range key with value the minimum and maximum value for that axis.
|
|
||||||
"""
|
"""
|
||||||
# Get ranges of vertices.
|
|
||||||
max_expand = (verts.max(0)[0] - verts.min(0)[0]).max()
|
|
||||||
verts_min = verts_center - max_expand
|
verts_min = verts_center - max_expand
|
||||||
verts_max = verts_center + max_expand
|
verts_max = verts_center + max_expand
|
||||||
bounds = torch.t(torch.stack((verts_min, verts_max)))
|
bounds = torch.t(torch.stack((verts_min, verts_max)))
|
||||||
@ -292,8 +401,8 @@ def _gen_updated_axis_bounds(
|
|||||||
if old_zrange is not None:
|
if old_zrange is not None:
|
||||||
z_range[0] = min(z_range[0], old_zrange[0])
|
z_range[0] = min(z_range[0], old_zrange[0])
|
||||||
z_range[1] = max(z_range[1], old_zrange[1])
|
z_range[1] = max(z_range[1], old_zrange[1])
|
||||||
axis_args_dict = axis_args._asdict()
|
|
||||||
xaxis = {"range": x_range, **axis_args_dict}
|
xaxis = {"range": x_range}
|
||||||
yaxis = {"range": y_range, **axis_args_dict}
|
yaxis = {"range": y_range}
|
||||||
zaxis = {"range": z_range, **axis_args_dict}
|
zaxis = {"range": z_range}
|
||||||
return xaxis, yaxis, zaxis
|
current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user