mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Plotly figure for visualizing pointclouds
Summary: Visualize a pointcloud in plotly. - customize lighting and light position - customizable axis arguments - customizable height and width of plotly figure - render batches in subplots or the same plot Reviewed By: nikhilaravi Differential Revision: D23872391 fbshipit-source-id: 9b1e1fd417500521be9d0eb85d71c77a538fa77c
This commit is contained in:
parent
8b6310359f
commit
8f1e9e1f06
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from .mesh_plotly import AxisArgs, Lighting, plot_meshes
|
from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds
|
||||||
|
|
||||||
|
|
||||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||||
|
@ -3,11 +3,12 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import NamedTuple, Optional, Tuple
|
from typing import NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
import torch
|
import torch
|
||||||
from plotly.subplots import make_subplots
|
from plotly.subplots import make_subplots
|
||||||
from pytorch3d.renderer import TexturesVertex
|
from pytorch3d.renderer import TexturesVertex
|
||||||
from pytorch3d.structures import Meshes
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
|
|
||||||
|
|
||||||
class AxisArgs(NamedTuple):
|
class AxisArgs(NamedTuple):
|
||||||
@ -113,6 +114,105 @@ def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **
|
|||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_pointclouds(
|
||||||
|
pointclouds: Pointclouds,
|
||||||
|
*,
|
||||||
|
in_subplots: bool = False,
|
||||||
|
ncols: int = 1,
|
||||||
|
max_points: int = 20000,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Takes a Pointclouds object and generates a plotly figure. If there is more than
|
||||||
|
one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be
|
||||||
|
visualized in an individual subplot with ncols number of subplots in the same row.
|
||||||
|
Otherwise, each pointcloud in the batch will be visualized as an individual trace in the
|
||||||
|
same plot. If the Pointclouds object has features that are size (3) or (4) then those
|
||||||
|
rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors
|
||||||
|
will be used. Assumes that all rgb/rgba feature values are in the range [0,1].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pointclouds: Pointclouds object which can contain a batch of pointclouds.
|
||||||
|
in_subplots: if each pointcloud should be visualized in an individual subplot.
|
||||||
|
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
||||||
|
ncols will be ignored.
|
||||||
|
max_points: maximum number of points to plot. If the cloud has more, they are
|
||||||
|
randomly subsampled.
|
||||||
|
**kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D
|
||||||
|
and any of the args xaxis, yaxis and zaxis which scene accepts.
|
||||||
|
Accepts axis_args which is an AxisArgs object that is applied to the 3 axes.
|
||||||
|
Also accepts subplot_titles, whichshould be a list of string titles
|
||||||
|
matching the number of subplots. Example settings for axis_args and lighting are
|
||||||
|
given at the top of this file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch,
|
||||||
|
the plotly figure will contain a plot with one trace per pointcloud, or with each
|
||||||
|
pointcloud in a separate subplot if in_subplots is True.
|
||||||
|
"""
|
||||||
|
pointclouds = pointclouds.detach().cpu()
|
||||||
|
subplot_titles = kwargs.get("subplot_titles", None)
|
||||||
|
fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles)
|
||||||
|
for i in range(len(pointclouds)):
|
||||||
|
verts = pointclouds[i].points_packed()
|
||||||
|
features = pointclouds[i].features_packed()
|
||||||
|
|
||||||
|
indices = None
|
||||||
|
if max_points is not None and verts.shape[0] > max_points:
|
||||||
|
indices = np.random.choice(verts.shape[0], max_points, replace=False)
|
||||||
|
verts = verts[indices]
|
||||||
|
|
||||||
|
color = None
|
||||||
|
if features is not None:
|
||||||
|
features = features[indices]
|
||||||
|
if features.shape[1] == 4: # rgba
|
||||||
|
template = "rgb(%d, %d, %d, %f)"
|
||||||
|
rgb = (features[:, :3] * 255).int()
|
||||||
|
color = [
|
||||||
|
template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])
|
||||||
|
]
|
||||||
|
|
||||||
|
if features.shape[1] == 3:
|
||||||
|
template = "rgb(%d, %d, %d)"
|
||||||
|
rgb = (features * 255).int()
|
||||||
|
color = [template % (r, g, b) for r, g, b in rgb]
|
||||||
|
|
||||||
|
trace_row = i // ncols + 1 if in_subplots else 1
|
||||||
|
trace_col = i % ncols + 1 if in_subplots else 1
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter3d( # pyre-ignore[16]
|
||||||
|
x=verts[:, 0],
|
||||||
|
y=verts[:, 1],
|
||||||
|
z=verts[:, 2],
|
||||||
|
marker={"color": color, "size": 1},
|
||||||
|
mode="markers",
|
||||||
|
),
|
||||||
|
row=trace_row,
|
||||||
|
col=trace_col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure update for every subplot.
|
||||||
|
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
|
||||||
|
current_layout = fig["layout"][plot_scene]
|
||||||
|
|
||||||
|
verts_center = verts.mean(0)
|
||||||
|
|
||||||
|
axis_args = kwargs.get("axis_args", AxisArgs())
|
||||||
|
|
||||||
|
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
|
||||||
|
verts, verts_center, current_layout, axis_args
|
||||||
|
)
|
||||||
|
xaxis.update(**kwargs.get("xaxis", {}))
|
||||||
|
yaxis.update(**kwargs.get("yaxis", {}))
|
||||||
|
zaxis.update(**kwargs.get("zaxis", {}))
|
||||||
|
|
||||||
|
current_layout.update(
|
||||||
|
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def _gen_fig_with_subplots(
|
def _gen_fig_with_subplots(
|
||||||
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list]
|
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list]
|
||||||
):
|
):
|
Loading…
x
Reference in New Issue
Block a user