RayBundle visualization

Summary: Extends plotly_vis to visualize `RayBundle`s.

Reviewed By: patricklabatut

Differential Revision: D29014098

fbshipit-source-id: 4dee426510a1fa53d4afefbe1bcdd003684c9932
This commit is contained in:
David Novotny 2021-07-01 17:30:20 -07:00 committed by Facebook GitHub Bot
parent 62ff77b49a
commit 4426a9d12c

View File

@ -11,12 +11,22 @@ import numpy as np
import plotly.graph_objects as go import plotly.graph_objects as go
import torch import torch
from plotly.subplots import make_subplots from plotly.subplots import make_subplots
from pytorch3d.renderer import TexturesVertex from pytorch3d.renderer import TexturesVertex, RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle]
def _get_struct_len(struct: Struct): # pragma: no cover
"""
Returns the length (usually corresponds to the batch size) of the input structure.
"""
return len(struct.directions) if isinstance(struct, RayBundle) else len(struct)
def get_camera_wireframe(scale: float = 0.3): # pragma: no cover def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
""" """
Returns a wireframe of a 3D line-plot of a camera symbol. Returns a wireframe of a 3D line-plot of a camera symbol.
@ -55,18 +65,22 @@ class Lighting(NamedTuple): # pragma: no cover
def plot_scene( def plot_scene(
plots: Dict[str, Dict[str, Union[Pointclouds, Meshes, CamerasBase]]], plots: Dict[str, Dict[str, Struct]],
*, *,
viewpoint_cameras: Optional[CamerasBase] = None, viewpoint_cameras: Optional[CamerasBase] = None,
ncols: int = 1, ncols: int = 1,
camera_scale: float = 0.3, camera_scale: float = 0.3,
pointcloud_max_points: int = 20000, pointcloud_max_points: int = 20000,
pointcloud_marker_size: int = 1, pointcloud_marker_size: int = 1,
raybundle_max_rays: int = 20000,
raybundle_max_points_per_ray: int = 1000,
raybundle_ray_point_marker_size: int = 1,
raybundle_ray_line_width: int = 1,
**kwargs, **kwargs,
): # pragma: no cover ): # pragma: no cover
""" """
Main function to visualize Meshes, Cameras and Pointclouds. Main function to visualize Cameras, Meshes, Pointclouds, and RayBundle.
Plots input Pointclouds, Meshes, and Cameras data into named subplots, Plots input Cameras, Meshes, Pointclouds, and RayBundle data into named subplots,
with named traces based on the dictionary keys. Cameras are with named traces based on the dictionary keys. Cameras are
rendered at the camera center location using a wireframe. rendered at the camera center location using a wireframe.
@ -87,6 +101,13 @@ def plot_scene(
pointcloud_max_points is used. pointcloud_max_points is used.
pointcloud_marker_size: the size of the points rendered by plotly pointcloud_marker_size: the size of the points rendered by plotly
when plotting a pointcloud. when plotting a pointcloud.
raybundle_max_rays: maximum number of rays of a RayBundle to visualize. Randomly
subsamples without replacement in case the number of rays is bigger than max_rays.
raybundle_max_points_per_ray: the maximum number of points per ray in RayBundle
to visualize. If more are present, a random sample of size
max_points_per_ray is used.
raybundle_ray_point_marker_size: the size of the ray points of a plotted RayBundle
raybundle_ray_line_width: the width of the plotted rays of a RayBundle
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis, **kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args, yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
which is an AxisArgs object that is applied to all 3 axes. which is an AxisArgs object that is applied to all 3 axes.
@ -186,6 +207,18 @@ def plot_scene(
The above example will render one subplot with the mesh object The above example will render one subplot with the mesh object
and two cameras. and two cameras.
RayBundle visualization is also supproted:
..code-block::python
cameras = PerspectiveCameras(...)
ray_bundle = RayBundle(origins=..., lengths=..., directions=..., xys=...)
fig = plot_scene({
"subplot1_title": {
"ray_bundle_trace_title": ray_bundle,
"cameras_trace_title": cameras,
},
})
fig.show()
For an example of using kwargs, see below: For an example of using kwargs, see below:
..code-block::python ..code-block::python
mesh = ... mesh = ...
@ -264,11 +297,22 @@ def plot_scene(
_add_camera_trace( _add_camera_trace(
fig, struct, trace_name, subplot_idx, ncols, camera_scale fig, struct, trace_name, subplot_idx, ncols, camera_scale
) )
elif isinstance(struct, RayBundle):
_add_ray_bundle_trace(
fig,
struct,
trace_name,
subplot_idx,
ncols,
raybundle_max_rays,
raybundle_max_points_per_ray,
raybundle_ray_point_marker_size,
raybundle_ray_line_width,
)
else: else:
raise ValueError( raise ValueError(
"struct {} is not a Cameras, Meshes or Pointclouds object".format( "struct {} is not a Cameras, Meshes, Pointclouds,".format(struct)
struct + " or RayBundle object."
)
) )
# Ensure update for every subplot. # Ensure update for every subplot.
@ -329,7 +373,8 @@ def plot_scene(
def plot_batch_individually( def plot_batch_individually(
batched_structs: Union[ batched_structs: Union[
List[Union[Meshes, Pointclouds, CamerasBase]], Meshes, Pointclouds, CamerasBase List[Struct],
Struct,
], ],
*, *,
viewpoint_cameras: Optional[CamerasBase] = None, viewpoint_cameras: Optional[CamerasBase] = None,
@ -340,26 +385,27 @@ def plot_batch_individually(
): # pragma: no cover ): # pragma: no cover
""" """
This is a higher level plotting function than plot_scene, for plotting This is a higher level plotting function than plot_scene, for plotting
Cameras, Meshes and Pointclouds in simple cases. The simplest use is to plot a Cameras, Meshes, Pointclouds, and RayBundle in simple cases. The simplest use
single Cameras, Meshes or Pointclouds object, where you just pass it in as a is to plot a single Cameras, Meshes, Pointclouds, or a RayBundle object,
one element list. This will plot each batch element in a separate subplot. where you just pass it in as a one element list. This will plot each batch
element in a separate subplot.
More generally, you can supply multiple Cameras, Meshes or Pointclouds More generally, you can supply multiple Cameras, Meshes, Pointclouds, or RayBundle
having the same batch size `n`. In this case, there will be `n` subplots, having the same batch size `n`. In this case, there will be `n` subplots,
each depicting the corresponding batch element of all the inputs. each depicting the corresponding batch element of all the inputs.
In addition, you can include Cameras, Meshes and Pointclouds of size 1 in In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
the input. These will either be rendered in the first subplot the input. These will either be rendered in the first subplot
(if extend_struct is False), or in every subplot. (if extend_struct is False), or in every subplot.
Args: Args:
batched_structs: a list of Cameras, Meshes and/or Pointclouds to be rendered. batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle
Each structure's corresponding batch element will be plotted in to be rendered. Each structure's corresponding batch element will be
a single subplot, resulting in n subplots for a batch of size n. plotted in a single subplot, resulting in n subplots for a batch of size n.
Every struct should either have the same batch size or be of batch size 1. Every struct should either have the same batch size or be of batch size 1.
See extend_struct and the description above for how batch size 1 structs See extend_struct and the description above for how batch size 1 structs
are handled. Also accepts a single Cameras, Meshes or Pointclouds object, are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle
which will have each individual element plotted in its own subplot. object, which will have each individual element plotted in its own subplot.
viewpoint_cameras: an instance of a Cameras object providing a location viewpoint_cameras: an instance of a Cameras object providing a location
to view the plotly plot from. If the batch size is equal to view the plotly plot from. If the batch size is equal
to the number of subplots, it is a one to one mapping. to the number of subplots, it is a one to one mapping.
@ -407,13 +453,14 @@ def plot_batch_individually(
return return
max_size = 0 max_size = 0
if isinstance(batched_structs, list): if isinstance(batched_structs, list):
max_size = max(len(s) for s in batched_structs) max_size = max(_get_struct_len(s) for s in batched_structs)
for struct in batched_structs: for struct in batched_structs:
if len(struct) not in (1, max_size): struct_len = _get_struct_len(struct)
msg = "invalid batch size {} provided: {}".format(len(struct), struct) if struct_len not in (1, max_size):
msg = "invalid batch size {} provided: {}".format(struct_len, struct)
raise ValueError(msg) raise ValueError(msg)
else: else:
max_size = len(batched_structs) max_size = _get_struct_len(batched_structs)
if max_size == 0: if max_size == 0:
msg = "No data is provided with at least one element" msg = "No data is provided with at least one element"
@ -437,7 +484,8 @@ def plot_batch_individually(
if isinstance(batched_structs, list): if isinstance(batched_structs, list):
for i, batched_struct in enumerate(batched_structs): for i, batched_struct in enumerate(batched_structs):
# check for whether this struct needs to be extended # check for whether this struct needs to be extended
if i >= len(batched_struct) and not extend_struct: batched_struct_len = _get_struct_len(batched_struct)
if i >= batched_struct_len and not extend_struct:
continue continue
_add_struct_from_batch( _add_struct_from_batch(
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1 batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
@ -453,10 +501,10 @@ def plot_batch_individually(
def _add_struct_from_batch( def _add_struct_from_batch(
batched_struct: Union[CamerasBase, Meshes, Pointclouds], batched_struct: Struct,
scene_num: int, scene_num: int,
subplot_title: str, subplot_title: str,
scene_dictionary: Dict[str, Dict[str, Union[CamerasBase, Meshes, Pointclouds]]], scene_dictionary: Dict[str, Dict[str, Struct]],
trace_idx: int = 1, trace_idx: int = 1,
): # pragma: no cover ): # pragma: no cover
""" """
@ -492,6 +540,15 @@ def _add_struct_from_batch(
# torch.Tensor, torch.nn.Module]` is not a function. # torch.Tensor, torch.nn.Module]` is not a function.
T = T[t_idx].unsqueeze(0) T = T[t_idx].unsqueeze(0)
struct = CamerasBase(device=batched_struct.device, R=R, T=T) struct = CamerasBase(device=batched_struct.device, R=R, T=T)
elif isinstance(batched_struct, RayBundle):
# for RayBundle we treat the 1st dim as the batch index
struct_idx = min(scene_num, len(batched_struct.lengths) - 1)
struct = RayBundle(
**{
attr: getattr(batched_struct, attr)[struct_idx]
for attr in ["origins", "directions", "lengths", "xys"]
}
)
else: # batched meshes and pointclouds are indexable else: # batched meshes and pointclouds are indexable
struct_idx = min(scene_num, len(batched_struct) - 1) struct_idx = min(scene_num, len(batched_struct) - 1)
struct = batched_struct[struct_idx] struct = batched_struct[struct_idx]
@ -702,6 +759,138 @@ def _add_camera_trace(
_update_axes_bounds(verts_center, max_expand, current_layout) _update_axes_bounds(verts_center, max_expand, current_layout)
def _add_ray_bundle_trace(
fig: go.Figure,
ray_bundle: RayBundle,
trace_name: str,
subplot_idx: int,
ncols: int,
max_rays: int,
max_points_per_ray: int,
marker_size: int,
line_width: int,
): # pragma: no cover
"""
Adds a trace rendering a RayBundle object to the passed in figure, with
a given name and in a specific subplot.
Args:
fig: plotly figure to add the trace within.
cameras: the Cameras object to render. It can be batched.
trace_name: name to label the trace with.
subplot_idx: identifies the subplot, with 0 being the top left.
ncols: the number of subplots per row.
max_rays: maximum number of plotted rays in total. Randomly subsamples
without replacement in case the number of rays is bigger than max_rays.
max_points_per_ray: maximum number of points plotted per ray.
marker_size: the size of the ray point markers.
line_width: the width of the ray lines.
"""
n_pts_per_ray = ray_bundle.lengths.shape[-1]
n_rays = ray_bundle.lengths.shape[:-1].numel() # pyre-ignore[16]
# flatten all batches of rays into a single big bundle
ray_bundle_flat = RayBundle(
**{
attr: torch.flatten(getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
for attr in ["origins", "directions", "lengths", "xys"]
}
)
# subsample the rays (if needed)
if n_rays > max_rays:
indices_rays = torch.randperm(n_rays)[:max_rays]
ray_bundle_flat = RayBundle(
**{
attr: getattr(ray_bundle_flat, attr)[indices_rays]
for attr in ["origins", "directions", "lengths", "xys"]
}
)
# make ray line endpoints
min_max_ray_depth = torch.stack(
[
ray_bundle_flat.lengths.min(dim=1).values, # pyre-ignore[16]
ray_bundle_flat.lengths.max(dim=1).values,
],
dim=-1,
)
ray_lines_endpoints = ray_bundle_to_ray_points(
ray_bundle_flat._replace(lengths=min_max_ray_depth)
)
# make the ray lines for plotly plotting
nan_tensor = torch.Tensor([[float("NaN")] * 3])
ray_lines = torch.empty(size=(1, 3))
for ray_line in ray_lines_endpoints:
# We combine the ray lines into a single tensor to plot them in a
# single trace. The NaNs are inserted between sets of ray lines
# so that the lines drawn by Plotly are not drawn between
# lines that belong to different rays.
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
fig.add_trace(
go.Scatter3d(
x=x,
y=y,
z=z,
marker={"size": 0.1},
line={"width": line_width},
name=trace_name,
),
row=row,
col=col,
)
# subsample the ray points (if needed)
if n_pts_per_ray > max_points_per_ray:
indices_ray_pts = torch.cat(
[
torch.randperm(n_pts_per_ray)[:max_points_per_ray] + ri * n_pts_per_ray
for ri in range(ray_bundle_flat.lengths.shape[0])
]
)
ray_bundle_flat = ray_bundle_flat._replace(
lengths=ray_bundle_flat.lengths.reshape(-1)[indices_ray_pts].reshape(
ray_bundle_flat.lengths.shape[0], -1
)
)
# plot the ray points
ray_points = (
ray_bundle_to_ray_points(ray_bundle_flat)
.view(-1, 3)
.detach()
.cpu()
.numpy()
.astype(float)
)
fig.add_trace(
go.Scatter3d(
x=ray_points[:, 0],
y=ray_points[:, 1],
z=ray_points[:, 2],
mode="markers",
name=trace_name + "_points",
marker={"size": marker_size},
),
row=row,
col=col,
)
# Access the current subplot's scene configuration
plot_scene = "scene" + str(subplot_idx + 1)
current_layout = fig["layout"][plot_scene]
# update the bounds of the axes for the current trace
all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
ray_points_center = all_ray_points.mean(dim=0)
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)
def _gen_fig_with_subplots( def _gen_fig_with_subplots(
batch_size: int, ncols: int, subplot_titles: List[str] batch_size: int, ncols: int, subplot_titles: List[str]
): # pragma: no cover ): # pragma: no cover