Plotly figure for visualizing pointclouds

Summary:
Visualize a pointcloud in plotly.
- customize lighting and light position
- customizable axis arguments
- customizable height and width of plotly figure
- render batches in subplots or the same plot

Reviewed By: nikhilaravi

Differential Revision: D23872391

fbshipit-source-id: 9b1e1fd417500521be9d0eb85d71c77a538fa77c
This commit is contained in:
Amitav Baruah 2020-10-01 16:47:41 -07:00 committed by Facebook GitHub Bot
parent 8b6310359f
commit 8f1e9e1f06
2 changed files with 102 additions and 2 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .mesh_plotly import AxisArgs, Lighting, plot_meshes
from .plotly_vis import AxisArgs, Lighting, plot_meshes, plot_pointclouds
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -3,11 +3,12 @@
import warnings
from typing import NamedTuple, Optional, Tuple
import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes
from pytorch3d.structures import Meshes, Pointclouds
class AxisArgs(NamedTuple):
@ -113,6 +114,105 @@ def plot_meshes(meshes: Meshes, *, in_subplots: bool = False, ncols: int = 1, **
return fig
def plot_pointclouds(
pointclouds: Pointclouds,
*,
in_subplots: bool = False,
ncols: int = 1,
max_points: int = 20000,
**kwargs,
):
"""
Takes a Pointclouds object and generates a plotly figure. If there is more than
one pointcloud in the batch, and in_subplots is set to be True, each pointcloud will be
visualized in an individual subplot with ncols number of subplots in the same row.
Otherwise, each pointcloud in the batch will be visualized as an individual trace in the
same plot. If the Pointclouds object has features that are size (3) or (4) then those
rgb/rgba values will be used for the plotly figure. Otherwise, plotly's default colors
will be used. Assumes that all rgb/rgba feature values are in the range [0,1].
Args:
pointclouds: Pointclouds object which can contain a batch of pointclouds.
in_subplots: if each pointcloud should be visualized in an individual subplot.
ncols: number of subplots in the same row if in_subplots is set to be True. Otherwise
ncols will be ignored.
max_points: maximum number of points to plot. If the cloud has more, they are
randomly subsampled.
**kwargs: Accepts lighting (a Lighting object) and lightposition for Scatter3D
and any of the args xaxis, yaxis and zaxis which scene accepts.
Accepts axis_args which is an AxisArgs object that is applied to the 3 axes.
Also accepts subplot_titles, whichshould be a list of string titles
matching the number of subplots. Example settings for axis_args and lighting are
given at the top of this file.
Returns:
Plotly figure of the pointcloud(s). If there is more than one pointcloud in the batch,
the plotly figure will contain a plot with one trace per pointcloud, or with each
pointcloud in a separate subplot if in_subplots is True.
"""
pointclouds = pointclouds.detach().cpu()
subplot_titles = kwargs.get("subplot_titles", None)
fig = _gen_fig_with_subplots(len(pointclouds), in_subplots, ncols, subplot_titles)
for i in range(len(pointclouds)):
verts = pointclouds[i].points_packed()
features = pointclouds[i].features_packed()
indices = None
if max_points is not None and verts.shape[0] > max_points:
indices = np.random.choice(verts.shape[0], max_points, replace=False)
verts = verts[indices]
color = None
if features is not None:
features = features[indices]
if features.shape[1] == 4: # rgba
template = "rgb(%d, %d, %d, %f)"
rgb = (features[:, :3] * 255).int()
color = [
template % (*rgb_, a_) for rgb_, a_ in zip(rgb, features[:, 3])
]
if features.shape[1] == 3:
template = "rgb(%d, %d, %d)"
rgb = (features * 255).int()
color = [template % (r, g, b) for r, g, b in rgb]
trace_row = i // ncols + 1 if in_subplots else 1
trace_col = i % ncols + 1 if in_subplots else 1
fig.add_trace(
go.Scatter3d( # pyre-ignore[16]
x=verts[:, 0],
y=verts[:, 1],
z=verts[:, 2],
marker={"color": color, "size": 1},
mode="markers",
),
row=trace_row,
col=trace_col,
)
# Ensure update for every subplot.
plot_scene = "scene" + str(i + 1) if in_subplots else "scene"
current_layout = fig["layout"][plot_scene]
verts_center = verts.mean(0)
axis_args = kwargs.get("axis_args", AxisArgs())
xaxis, yaxis, zaxis = _gen_updated_axis_bounds(
verts, verts_center, current_layout, axis_args
)
xaxis.update(**kwargs.get("xaxis", {}))
yaxis.update(**kwargs.get("yaxis", {}))
zaxis.update(**kwargs.get("zaxis", {}))
current_layout.update(
{"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis, "aspectmode": "cube"}
)
return fig
def _gen_fig_with_subplots(
batch_size: int, in_subplots: bool, ncols: int, subplot_titles: Optional[list]
):