mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Plotly figure for visualizing pointclouds
Summary: Visualize a pointcloud in plotly. - customize lighting and light position - customizable axis arguments - customizable height and width of plotly figure - render batches in subplots or the same plot Reviewed By: nikhilaravi Differential Revision: D23872391 fbshipit-source-id: 9b1e1fd417500521be9d0eb85d71c77a538fa77c
This commit is contained in:
parent
8b6310359f
commit
8f1e9e1f06
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .mesh_plotly import AxisArgs, Lighting, plot_meshes
|
||||
from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
@ -3,11 +3,12 @@
|
||||
import warnings
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import torch
|
||||
from plotly.subplots import make_subplots
|
||||
from pytorch3d.renderer import TexturesVertex
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
|
||||
class AxisArgs(NamedTuple):
|
||||
@ -113,6 +114,105 @@ def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **
|
||||
return fig
|
||||
|
||||
|
||||
def plot_pointclouds(
|
||||
pointclouds: Pointclouds,
|
||||
*,
|
||||
in_subplots: bool = False,
|
||||
ncols: int = 1,
|
||||
max_points: int = 20000,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Takes a Pointclouds object and generates a plotly figure. If there is more than
|
||||
one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be
|
||||
visualized in an individual subplot with ncols number of subplots in the same row.
|
||||
Otherwise, each pointcloud in the batch will be visualized as an individual trace in the
|
||||
same plot. If the Pointclouds object has features that are size (3) or (4) then those
|
||||
rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors
|
||||
will be used. Assumes that all rgb/rgba feature values are in the range [0,1].
|
||||
|
||||
Args:
|
||||
pointclouds: Pointclouds object which can contain a batch of pointclouds.
|
||||
in_subplots: if each pointcloud should be visualized in an individual subplot.
|
||||
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
|
||||
ncols will be ignored.
|
||||
max_points: maximum number of points to plot. If the cloud has more, they are
|
||||
randomly subsampled.
|
||||
**kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D
|
||||
and any of the args xaxis, yaxis and zaxis which scene accepts.
|
||||
Accepts axis_args which is an AxisArgs object that is applied to the 3 axes.
|
||||
Also accepts subplot_titles, whichshould be a list of string titles
|
||||
matching the number of subplots. Example settings for axis_args and lighting are
|
||||
given at the top of this file.
|
||||
|
||||
Returns:
|
||||
Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch,
|
||||
the plotly figure will contain a plot with one trace per pointcloud, or with each
|
||||
pointcloud in a separate subplot if in_subplots is True.
|
||||
"""
|
||||
pointclouds = pointclouds.detach().cpu()
|
||||
subplot_titles = kwargs.get("subplot_titles", None)
|
||||
fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles)
|
||||
for i in range(len(pointclouds)):
|
||||
verts = pointclouds[i].points_packed()
|
||||
features = pointclouds[i].features_packed()
|
||||
|
||||
indices = None
|
||||
if max_points is not None and verts.shape[0] > max_points:
|
||||
indices = np.random.choice(verts.shape[0], max_points, replace=False)
|
||||
verts = verts[indices]
|
||||
|
||||
color = None
|
||||
if features is not None:
|
||||
features = features[indices]
|
||||
if features.shape[1] == 4: # rgba
|
||||
template = "rgb(%d, %d, %d, %f)"
|
||||
rgb = (features[:, :3] * 255).int()
|
||||
color = [
|
||||
template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])
|
||||
]
|
||||
|
||||
if features.shape[1] == 3:
|
||||
template = "rgb(%d, %d, %d)"
|
||||
rgb = (features * 255).int()
|
||||
color = [template % (r, g, b) for r, g, b in rgb]
|
||||
|
||||
trace_row = i // ncols + 1 if in_subplots else 1
|
||||
trace_col = i % ncols + 1 if in_subplots else 1
|
||||
fig.add_trace(
|
||||
go.Scatter3d( # pyre-ignore[16]
|
||||
x=verts[:, 0],
|
||||
y=verts[:, 1],
|
||||
z=verts[:, 2],
|
||||
marker={"color": color, "size": 1},
|
||||
mode="markers",
|
||||
),
|
||||
row=trace_row,
|
||||
col=trace_col,
|
||||
)
|
||||
|
||||
# Ensure update for every subplot.
|
||||
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
|
||||
current_layout = fig["layout"][plot_scene]
|
||||
|
||||
verts_center = verts.mean(0)
|
||||
|
||||
axis_args = kwargs.get("axis_args", AxisArgs())
|
||||
|
||||
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
|
||||
verts, verts_center, current_layout, axis_args
|
||||
)
|
||||
xaxis.update(**kwargs.get("xaxis", {}))
|
||||
yaxis.update(**kwargs.get("yaxis", {}))
|
||||
zaxis.update(**kwargs.get("zaxis", {}))
|
||||
|
||||
current_layout.update(
|
||||
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _gen_fig_with_subplots(
|
||||
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list]
|
||||
):
|
Loading…
x
Reference in New Issue
Block a user