mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
RayBundle visualization
Summary: Extends plotly_vis to visualize `RayBundle`s. Reviewed By: patricklabatut Differential Revision: D29014098 fbshipit-source-id: 4dee426510a1fa53d4afefbe1bcdd003684c9932
This commit is contained in:
parent
62ff77b49a
commit
4426a9d12c
@ -11,12 +11,22 @@ import numpy as np
|
|||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
import torch
|
import torch
|
||||||
from plotly.subplots import make_subplots
|
from plotly.subplots import make_subplots
|
||||||
from pytorch3d.renderer import TexturesVertex
|
from pytorch3d.renderer import TexturesVertex, RayBundle, ray_bundle_to_ray_points
|
||||||
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
|
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
|
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
|
||||||
|
|
||||||
|
|
||||||
|
Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_struct_len(struct: Struct): # pragma: no cover
|
||||||
|
"""
|
||||||
|
Returns the length (usually corresponds to the batch size) of the input structure.
|
||||||
|
"""
|
||||||
|
return len(struct.directions) if isinstance(struct, RayBundle) else len(struct)
|
||||||
|
|
||||||
|
|
||||||
def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
|
def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Returns a wireframe of a 3D line-plot of a camera symbol.
|
Returns a wireframe of a 3D line-plot of a camera symbol.
|
||||||
@ -55,18 +65,22 @@ class Lighting(NamedTuple): # pragma: no cover
|
|||||||
|
|
||||||
|
|
||||||
def plot_scene(
|
def plot_scene(
|
||||||
plots: Dict[str, Dict[str, Union[Pointclouds, Meshes, CamerasBase]]],
|
plots: Dict[str, Dict[str, Struct]],
|
||||||
*,
|
*,
|
||||||
viewpoint_cameras: Optional[CamerasBase] = None,
|
viewpoint_cameras: Optional[CamerasBase] = None,
|
||||||
ncols: int = 1,
|
ncols: int = 1,
|
||||||
camera_scale: float = 0.3,
|
camera_scale: float = 0.3,
|
||||||
pointcloud_max_points: int = 20000,
|
pointcloud_max_points: int = 20000,
|
||||||
pointcloud_marker_size: int = 1,
|
pointcloud_marker_size: int = 1,
|
||||||
|
raybundle_max_rays: int = 20000,
|
||||||
|
raybundle_max_points_per_ray: int = 1000,
|
||||||
|
raybundle_ray_point_marker_size: int = 1,
|
||||||
|
raybundle_ray_line_width: int = 1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
): # pragma: no cover
|
): # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Main function to visualize Meshes, Cameras and Pointclouds.
|
Main function to visualize Cameras, Meshes, Pointclouds, and RayBundle.
|
||||||
Plots input Pointclouds, Meshes, and Cameras data into named subplots,
|
Plots input Cameras, Meshes, Pointclouds, and RayBundle data into named subplots,
|
||||||
with named traces based on the dictionary keys. Cameras are
|
with named traces based on the dictionary keys. Cameras are
|
||||||
rendered at the camera center location using a wireframe.
|
rendered at the camera center location using a wireframe.
|
||||||
|
|
||||||
@ -87,6 +101,13 @@ def plot_scene(
|
|||||||
pointcloud_max_points is used.
|
pointcloud_max_points is used.
|
||||||
pointcloud_marker_size: the size of the points rendered by plotly
|
pointcloud_marker_size: the size of the points rendered by plotly
|
||||||
when plotting a pointcloud.
|
when plotting a pointcloud.
|
||||||
|
raybundle_max_rays: maximum number of rays of a RayBundle to visualize. Randomly
|
||||||
|
subsamples without replacement in case the number of rays is bigger than max_rays.
|
||||||
|
raybundle_max_points_per_ray: the maximum number of points per ray in RayBundle
|
||||||
|
to visualize. If more are present, a random sample of size
|
||||||
|
max_points_per_ray is used.
|
||||||
|
raybundle_ray_point_marker_size: the size of the ray points of a plotted RayBundle
|
||||||
|
raybundle_ray_line_width: the width of the plotted rays of a RayBundle
|
||||||
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
|
**kwargs: Accepts lighting (a Lighting object) and any of the args xaxis,
|
||||||
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
|
yaxis and zaxis which Plotly's scene accepts. Accepts axis_args,
|
||||||
which is an AxisArgs object that is applied to all 3 axes.
|
which is an AxisArgs object that is applied to all 3 axes.
|
||||||
@ -186,6 +207,18 @@ def plot_scene(
|
|||||||
The above example will render one subplot with the mesh object
|
The above example will render one subplot with the mesh object
|
||||||
and two cameras.
|
and two cameras.
|
||||||
|
|
||||||
|
RayBundle visualization is also supproted:
|
||||||
|
..code-block::python
|
||||||
|
cameras = PerspectiveCameras(...)
|
||||||
|
ray_bundle = RayBundle(origins=..., lengths=..., directions=..., xys=...)
|
||||||
|
fig = plot_scene({
|
||||||
|
"subplot1_title": {
|
||||||
|
"ray_bundle_trace_title": ray_bundle,
|
||||||
|
"cameras_trace_title": cameras,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fig.show()
|
||||||
|
|
||||||
For an example of using kwargs, see below:
|
For an example of using kwargs, see below:
|
||||||
..code-block::python
|
..code-block::python
|
||||||
mesh = ...
|
mesh = ...
|
||||||
@ -264,11 +297,22 @@ def plot_scene(
|
|||||||
_add_camera_trace(
|
_add_camera_trace(
|
||||||
fig, struct, trace_name, subplot_idx, ncols, camera_scale
|
fig, struct, trace_name, subplot_idx, ncols, camera_scale
|
||||||
)
|
)
|
||||||
|
elif isinstance(struct, RayBundle):
|
||||||
|
_add_ray_bundle_trace(
|
||||||
|
fig,
|
||||||
|
struct,
|
||||||
|
trace_name,
|
||||||
|
subplot_idx,
|
||||||
|
ncols,
|
||||||
|
raybundle_max_rays,
|
||||||
|
raybundle_max_points_per_ray,
|
||||||
|
raybundle_ray_point_marker_size,
|
||||||
|
raybundle_ray_line_width,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"struct {} is not a Cameras, Meshes or Pointclouds object".format(
|
"struct {} is not a Cameras, Meshes, Pointclouds,".format(struct)
|
||||||
struct
|
+ " or RayBundle object."
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure update for every subplot.
|
# Ensure update for every subplot.
|
||||||
@ -329,7 +373,8 @@ def plot_scene(
|
|||||||
|
|
||||||
def plot_batch_individually(
|
def plot_batch_individually(
|
||||||
batched_structs: Union[
|
batched_structs: Union[
|
||||||
List[Union[Meshes, Pointclouds, CamerasBase]], Meshes, Pointclouds, CamerasBase
|
List[Struct],
|
||||||
|
Struct,
|
||||||
],
|
],
|
||||||
*,
|
*,
|
||||||
viewpoint_cameras: Optional[CamerasBase] = None,
|
viewpoint_cameras: Optional[CamerasBase] = None,
|
||||||
@ -340,26 +385,27 @@ def plot_batch_individually(
|
|||||||
): # pragma: no cover
|
): # pragma: no cover
|
||||||
"""
|
"""
|
||||||
This is a higher level plotting function than plot_scene, for plotting
|
This is a higher level plotting function than plot_scene, for plotting
|
||||||
Cameras, Meshes and Pointclouds in simple cases. The simplest use is to plot a
|
Cameras, Meshes, Pointclouds, and RayBundle in simple cases. The simplest use
|
||||||
single Cameras, Meshes or Pointclouds object, where you just pass it in as a
|
is to plot a single Cameras, Meshes, Pointclouds, or a RayBundle object,
|
||||||
one element list. This will plot each batch element in a separate subplot.
|
where you just pass it in as a one element list. This will plot each batch
|
||||||
|
element in a separate subplot.
|
||||||
|
|
||||||
More generally, you can supply multiple Cameras, Meshes or Pointclouds
|
More generally, you can supply multiple Cameras, Meshes, Pointclouds, or RayBundle
|
||||||
having the same batch size `n`. In this case, there will be `n` subplots,
|
having the same batch size `n`. In this case, there will be `n` subplots,
|
||||||
each depicting the corresponding batch element of all the inputs.
|
each depicting the corresponding batch element of all the inputs.
|
||||||
|
|
||||||
In addition, you can include Cameras, Meshes and Pointclouds of size 1 in
|
In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
|
||||||
the input. These will either be rendered in the first subplot
|
the input. These will either be rendered in the first subplot
|
||||||
(if extend_struct is False), or in every subplot.
|
(if extend_struct is False), or in every subplot.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batched_structs: a list of Cameras, Meshes and/or Pointclouds to be rendered.
|
batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle
|
||||||
Each structure's corresponding batch element will be plotted in
|
to be rendered. Each structure's corresponding batch element will be
|
||||||
a single subplot, resulting in n subplots for a batch of size n.
|
plotted in a single subplot, resulting in n subplots for a batch of size n.
|
||||||
Every struct should either have the same batch size or be of batch size 1.
|
Every struct should either have the same batch size or be of batch size 1.
|
||||||
See extend_struct and the description above for how batch size 1 structs
|
See extend_struct and the description above for how batch size 1 structs
|
||||||
are handled. Also accepts a single Cameras, Meshes or Pointclouds object,
|
are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle
|
||||||
which will have each individual element plotted in its own subplot.
|
object, which will have each individual element plotted in its own subplot.
|
||||||
viewpoint_cameras: an instance of a Cameras object providing a location
|
viewpoint_cameras: an instance of a Cameras object providing a location
|
||||||
to view the plotly plot from. If the batch size is equal
|
to view the plotly plot from. If the batch size is equal
|
||||||
to the number of subplots, it is a one to one mapping.
|
to the number of subplots, it is a one to one mapping.
|
||||||
@ -407,13 +453,14 @@ def plot_batch_individually(
|
|||||||
return
|
return
|
||||||
max_size = 0
|
max_size = 0
|
||||||
if isinstance(batched_structs, list):
|
if isinstance(batched_structs, list):
|
||||||
max_size = max(len(s) for s in batched_structs)
|
max_size = max(_get_struct_len(s) for s in batched_structs)
|
||||||
for struct in batched_structs:
|
for struct in batched_structs:
|
||||||
if len(struct) not in (1, max_size):
|
struct_len = _get_struct_len(struct)
|
||||||
msg = "invalid batch size {} provided: {}".format(len(struct), struct)
|
if struct_len not in (1, max_size):
|
||||||
|
msg = "invalid batch size {} provided: {}".format(struct_len, struct)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
else:
|
||||||
max_size = len(batched_structs)
|
max_size = _get_struct_len(batched_structs)
|
||||||
|
|
||||||
if max_size == 0:
|
if max_size == 0:
|
||||||
msg = "No data is provided with at least one element"
|
msg = "No data is provided with at least one element"
|
||||||
@ -437,7 +484,8 @@ def plot_batch_individually(
|
|||||||
if isinstance(batched_structs, list):
|
if isinstance(batched_structs, list):
|
||||||
for i, batched_struct in enumerate(batched_structs):
|
for i, batched_struct in enumerate(batched_structs):
|
||||||
# check for whether this struct needs to be extended
|
# check for whether this struct needs to be extended
|
||||||
if i >= len(batched_struct) and not extend_struct:
|
batched_struct_len = _get_struct_len(batched_struct)
|
||||||
|
if i >= batched_struct_len and not extend_struct:
|
||||||
continue
|
continue
|
||||||
_add_struct_from_batch(
|
_add_struct_from_batch(
|
||||||
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
|
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
|
||||||
@ -453,10 +501,10 @@ def plot_batch_individually(
|
|||||||
|
|
||||||
|
|
||||||
def _add_struct_from_batch(
|
def _add_struct_from_batch(
|
||||||
batched_struct: Union[CamerasBase, Meshes, Pointclouds],
|
batched_struct: Struct,
|
||||||
scene_num: int,
|
scene_num: int,
|
||||||
subplot_title: str,
|
subplot_title: str,
|
||||||
scene_dictionary: Dict[str, Dict[str, Union[CamerasBase, Meshes, Pointclouds]]],
|
scene_dictionary: Dict[str, Dict[str, Struct]],
|
||||||
trace_idx: int = 1,
|
trace_idx: int = 1,
|
||||||
): # pragma: no cover
|
): # pragma: no cover
|
||||||
"""
|
"""
|
||||||
@ -492,6 +540,15 @@ def _add_struct_from_batch(
|
|||||||
# torch.Tensor, torch.nn.Module]` is not a function.
|
# torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
T = T[t_idx].unsqueeze(0)
|
T = T[t_idx].unsqueeze(0)
|
||||||
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
|
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
|
||||||
|
elif isinstance(batched_struct, RayBundle):
|
||||||
|
# for RayBundle we treat the 1st dim as the batch index
|
||||||
|
struct_idx = min(scene_num, len(batched_struct.lengths) - 1)
|
||||||
|
struct = RayBundle(
|
||||||
|
**{
|
||||||
|
attr: getattr(batched_struct, attr)[struct_idx]
|
||||||
|
for attr in ["origins", "directions", "lengths", "xys"]
|
||||||
|
}
|
||||||
|
)
|
||||||
else: # batched meshes and pointclouds are indexable
|
else: # batched meshes and pointclouds are indexable
|
||||||
struct_idx = min(scene_num, len(batched_struct) - 1)
|
struct_idx = min(scene_num, len(batched_struct) - 1)
|
||||||
struct = batched_struct[struct_idx]
|
struct = batched_struct[struct_idx]
|
||||||
@ -702,6 +759,138 @@ def _add_camera_trace(
|
|||||||
_update_axes_bounds(verts_center, max_expand, current_layout)
|
_update_axes_bounds(verts_center, max_expand, current_layout)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_ray_bundle_trace(
|
||||||
|
fig: go.Figure,
|
||||||
|
ray_bundle: RayBundle,
|
||||||
|
trace_name: str,
|
||||||
|
subplot_idx: int,
|
||||||
|
ncols: int,
|
||||||
|
max_rays: int,
|
||||||
|
max_points_per_ray: int,
|
||||||
|
marker_size: int,
|
||||||
|
line_width: int,
|
||||||
|
): # pragma: no cover
|
||||||
|
"""
|
||||||
|
Adds a trace rendering a RayBundle object to the passed in figure, with
|
||||||
|
a given name and in a specific subplot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fig: plotly figure to add the trace within.
|
||||||
|
cameras: the Cameras object to render. It can be batched.
|
||||||
|
trace_name: name to label the trace with.
|
||||||
|
subplot_idx: identifies the subplot, with 0 being the top left.
|
||||||
|
ncols: the number of subplots per row.
|
||||||
|
max_rays: maximum number of plotted rays in total. Randomly subsamples
|
||||||
|
without replacement in case the number of rays is bigger than max_rays.
|
||||||
|
max_points_per_ray: maximum number of points plotted per ray.
|
||||||
|
marker_size: the size of the ray point markers.
|
||||||
|
line_width: the width of the ray lines.
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_pts_per_ray = ray_bundle.lengths.shape[-1]
|
||||||
|
n_rays = ray_bundle.lengths.shape[:-1].numel() # pyre-ignore[16]
|
||||||
|
|
||||||
|
# flatten all batches of rays into a single big bundle
|
||||||
|
ray_bundle_flat = RayBundle(
|
||||||
|
**{
|
||||||
|
attr: torch.flatten(getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
|
||||||
|
for attr in ["origins", "directions", "lengths", "xys"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# subsample the rays (if needed)
|
||||||
|
if n_rays > max_rays:
|
||||||
|
indices_rays = torch.randperm(n_rays)[:max_rays]
|
||||||
|
ray_bundle_flat = RayBundle(
|
||||||
|
**{
|
||||||
|
attr: getattr(ray_bundle_flat, attr)[indices_rays]
|
||||||
|
for attr in ["origins", "directions", "lengths", "xys"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# make ray line endpoints
|
||||||
|
min_max_ray_depth = torch.stack(
|
||||||
|
[
|
||||||
|
ray_bundle_flat.lengths.min(dim=1).values, # pyre-ignore[16]
|
||||||
|
ray_bundle_flat.lengths.max(dim=1).values,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
ray_lines_endpoints = ray_bundle_to_ray_points(
|
||||||
|
ray_bundle_flat._replace(lengths=min_max_ray_depth)
|
||||||
|
)
|
||||||
|
|
||||||
|
# make the ray lines for plotly plotting
|
||||||
|
nan_tensor = torch.Tensor([[float("NaN")] * 3])
|
||||||
|
ray_lines = torch.empty(size=(1, 3))
|
||||||
|
for ray_line in ray_lines_endpoints:
|
||||||
|
# We combine the ray lines into a single tensor to plot them in a
|
||||||
|
# single trace. The NaNs are inserted between sets of ray lines
|
||||||
|
# so that the lines drawn by Plotly are not drawn between
|
||||||
|
# lines that belong to different rays.
|
||||||
|
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
|
||||||
|
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
|
||||||
|
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter3d(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
z=z,
|
||||||
|
marker={"size": 0.1},
|
||||||
|
line={"width": line_width},
|
||||||
|
name=trace_name,
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# subsample the ray points (if needed)
|
||||||
|
if n_pts_per_ray > max_points_per_ray:
|
||||||
|
indices_ray_pts = torch.cat(
|
||||||
|
[
|
||||||
|
torch.randperm(n_pts_per_ray)[:max_points_per_ray] + ri * n_pts_per_ray
|
||||||
|
for ri in range(ray_bundle_flat.lengths.shape[0])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
ray_bundle_flat = ray_bundle_flat._replace(
|
||||||
|
lengths=ray_bundle_flat.lengths.reshape(-1)[indices_ray_pts].reshape(
|
||||||
|
ray_bundle_flat.lengths.shape[0], -1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# plot the ray points
|
||||||
|
ray_points = (
|
||||||
|
ray_bundle_to_ray_points(ray_bundle_flat)
|
||||||
|
.view(-1, 3)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
.astype(float)
|
||||||
|
)
|
||||||
|
fig.add_trace(
|
||||||
|
go.Scatter3d(
|
||||||
|
x=ray_points[:, 0],
|
||||||
|
y=ray_points[:, 1],
|
||||||
|
z=ray_points[:, 2],
|
||||||
|
mode="markers",
|
||||||
|
name=trace_name + "_points",
|
||||||
|
marker={"size": marker_size},
|
||||||
|
),
|
||||||
|
row=row,
|
||||||
|
col=col,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Access the current subplot's scene configuration
|
||||||
|
plot_scene = "scene" + str(subplot_idx + 1)
|
||||||
|
current_layout = fig["layout"][plot_scene]
|
||||||
|
|
||||||
|
# update the bounds of the axes for the current trace
|
||||||
|
all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
|
||||||
|
ray_points_center = all_ray_points.mean(dim=0)
|
||||||
|
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
|
||||||
|
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)
|
||||||
|
|
||||||
|
|
||||||
def _gen_fig_with_subplots(
|
def _gen_fig_with_subplots(
|
||||||
batch_size: int, ncols: int, subplot_titles: List[str]
|
batch_size: int, ncols: int, subplot_titles: List[str]
|
||||||
): # pragma: no cover
|
): # pragma: no cover
|
||||||
|
Loading…
x
Reference in New Issue
Block a user