diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index ec9717b5..9e29aa1c 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -11,12 +11,22 @@ import numpy as np import plotly.graph_objects as go import torch from plotly.subplots import make_subplots -from pytorch3d.renderer import TexturesVertex +from pytorch3d.renderer import TexturesVertex, RayBundle, ray_bundle_to_ray_points from pytorch3d.renderer.camera_utils import camera_to_eye_at_up from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene +Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle] + + +def _get_struct_len(struct: Struct): # pragma: no cover + """ + Returns the length (usually corresponds to the batch size) of the input structure. + """ + return len(struct.directions) if isinstance(struct, RayBundle) else len(struct) + + def get_camera_wireframe(scale: float = 0.3): # pragma: no cover """ Returns a wireframe of a 3D line-plot of a camera symbol. @@ -55,18 +65,22 @@ class Lighting(NamedTuple): # pragma: no cover def plot_scene( - plots: Dict[str, Dict[str, Union[Pointclouds, Meshes, CamerasBase]]], + plots: Dict[str, Dict[str, Struct]], *, viewpoint_cameras: Optional[CamerasBase] = None, ncols: int = 1, camera_scale: float = 0.3, pointcloud_max_points: int = 20000, pointcloud_marker_size: int = 1, + raybundle_max_rays: int = 20000, + raybundle_max_points_per_ray: int = 1000, + raybundle_ray_point_marker_size: int = 1, + raybundle_ray_line_width: int = 1, **kwargs, ): # pragma: no cover """ - Main function to visualize Meshes, Cameras and Pointclouds. - Plots input Pointclouds, Meshes, and Cameras data into named subplots, + Main function to visualize Cameras, Meshes, Pointclouds, and RayBundle. + Plots input Cameras, Meshes, Pointclouds, and RayBundle data into named subplots, with named traces based on the dictionary keys. Cameras are rendered at the camera center location using a wireframe. @@ -87,6 +101,13 @@ def plot_scene( pointcloud_max_points is used. pointcloud_marker_size: the size of the points rendered by plotly when plotting a pointcloud. + raybundle_max_rays: maximum number of rays of a RayBundle to visualize. Randomly + subsamples without replacement in case the number of rays is bigger than max_rays. + raybundle_max_points_per_ray: the maximum number of points per ray in RayBundle + to visualize. If more are present, a random sample of size + max_points_per_ray is used. + raybundle_ray_point_marker_size: the size of the ray points of a plotted RayBundle + raybundle_ray_line_width: the width of the plotted rays of a RayBundle **kwargs: Accepts lighting (a Lighting object) and any of the args xaxis, yaxis and zaxis which Plotly's scene accepts. Accepts axis_args, which is an AxisArgs object that is applied to all 3 axes. @@ -186,6 +207,18 @@ def plot_scene( The above example will render one subplot with the mesh object and two cameras. + RayBundle visualization is also supproted: + ..code-block::python + cameras = PerspectiveCameras(...) + ray_bundle = RayBundle(origins=..., lengths=..., directions=..., xys=...) + fig = plot_scene({ + "subplot1_title": { + "ray_bundle_trace_title": ray_bundle, + "cameras_trace_title": cameras, + }, + }) + fig.show() + For an example of using kwargs, see below: ..code-block::python mesh = ... @@ -264,11 +297,22 @@ def plot_scene( _add_camera_trace( fig, struct, trace_name, subplot_idx, ncols, camera_scale ) + elif isinstance(struct, RayBundle): + _add_ray_bundle_trace( + fig, + struct, + trace_name, + subplot_idx, + ncols, + raybundle_max_rays, + raybundle_max_points_per_ray, + raybundle_ray_point_marker_size, + raybundle_ray_line_width, + ) else: raise ValueError( - "struct {} is not a Cameras, Meshes or Pointclouds object".format( - struct - ) + "struct {} is not a Cameras, Meshes, Pointclouds,".format(struct) + + " or RayBundle object." ) # Ensure update for every subplot. @@ -329,7 +373,8 @@ def plot_scene( def plot_batch_individually( batched_structs: Union[ - List[Union[Meshes, Pointclouds, CamerasBase]], Meshes, Pointclouds, CamerasBase + List[Struct], + Struct, ], *, viewpoint_cameras: Optional[CamerasBase] = None, @@ -340,26 +385,27 @@ def plot_batch_individually( ): # pragma: no cover """ This is a higher level plotting function than plot_scene, for plotting - Cameras, Meshes and Pointclouds in simple cases. The simplest use is to plot a - single Cameras, Meshes or Pointclouds object, where you just pass it in as a - one element list. This will plot each batch element in a separate subplot. + Cameras, Meshes, Pointclouds, and RayBundle in simple cases. The simplest use + is to plot a single Cameras, Meshes, Pointclouds, or a RayBundle object, + where you just pass it in as a one element list. This will plot each batch + element in a separate subplot. - More generally, you can supply multiple Cameras, Meshes or Pointclouds + More generally, you can supply multiple Cameras, Meshes, Pointclouds, or RayBundle having the same batch size `n`. In this case, there will be `n` subplots, each depicting the corresponding batch element of all the inputs. - In addition, you can include Cameras, Meshes and Pointclouds of size 1 in + In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in the input. These will either be rendered in the first subplot (if extend_struct is False), or in every subplot. Args: - batched_structs: a list of Cameras, Meshes and/or Pointclouds to be rendered. - Each structure's corresponding batch element will be plotted in - a single subplot, resulting in n subplots for a batch of size n. + batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle + to be rendered. Each structure's corresponding batch element will be + plotted in a single subplot, resulting in n subplots for a batch of size n. Every struct should either have the same batch size or be of batch size 1. See extend_struct and the description above for how batch size 1 structs - are handled. Also accepts a single Cameras, Meshes or Pointclouds object, - which will have each individual element plotted in its own subplot. + are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle + object, which will have each individual element plotted in its own subplot. viewpoint_cameras: an instance of a Cameras object providing a location to view the plotly plot from. If the batch size is equal to the number of subplots, it is a one to one mapping. @@ -407,13 +453,14 @@ def plot_batch_individually( return max_size = 0 if isinstance(batched_structs, list): - max_size = max(len(s) for s in batched_structs) + max_size = max(_get_struct_len(s) for s in batched_structs) for struct in batched_structs: - if len(struct) not in (1, max_size): - msg = "invalid batch size {} provided: {}".format(len(struct), struct) + struct_len = _get_struct_len(struct) + if struct_len not in (1, max_size): + msg = "invalid batch size {} provided: {}".format(struct_len, struct) raise ValueError(msg) else: - max_size = len(batched_structs) + max_size = _get_struct_len(batched_structs) if max_size == 0: msg = "No data is provided with at least one element" @@ -437,7 +484,8 @@ def plot_batch_individually( if isinstance(batched_structs, list): for i, batched_struct in enumerate(batched_structs): # check for whether this struct needs to be extended - if i >= len(batched_struct) and not extend_struct: + batched_struct_len = _get_struct_len(batched_struct) + if i >= batched_struct_len and not extend_struct: continue _add_struct_from_batch( batched_struct, scene_num, subplot_title, scene_dictionary, i + 1 @@ -453,10 +501,10 @@ def plot_batch_individually( def _add_struct_from_batch( - batched_struct: Union[CamerasBase, Meshes, Pointclouds], + batched_struct: Struct, scene_num: int, subplot_title: str, - scene_dictionary: Dict[str, Dict[str, Union[CamerasBase, Meshes, Pointclouds]]], + scene_dictionary: Dict[str, Dict[str, Struct]], trace_idx: int = 1, ): # pragma: no cover """ @@ -492,6 +540,15 @@ def _add_struct_from_batch( # torch.Tensor, torch.nn.Module]` is not a function. T = T[t_idx].unsqueeze(0) struct = CamerasBase(device=batched_struct.device, R=R, T=T) + elif isinstance(batched_struct, RayBundle): + # for RayBundle we treat the 1st dim as the batch index + struct_idx = min(scene_num, len(batched_struct.lengths) - 1) + struct = RayBundle( + **{ + attr: getattr(batched_struct, attr)[struct_idx] + for attr in ["origins", "directions", "lengths", "xys"] + } + ) else: # batched meshes and pointclouds are indexable struct_idx = min(scene_num, len(batched_struct) - 1) struct = batched_struct[struct_idx] @@ -702,6 +759,138 @@ def _add_camera_trace( _update_axes_bounds(verts_center, max_expand, current_layout) +def _add_ray_bundle_trace( + fig: go.Figure, + ray_bundle: RayBundle, + trace_name: str, + subplot_idx: int, + ncols: int, + max_rays: int, + max_points_per_ray: int, + marker_size: int, + line_width: int, +): # pragma: no cover + """ + Adds a trace rendering a RayBundle object to the passed in figure, with + a given name and in a specific subplot. + + Args: + fig: plotly figure to add the trace within. + cameras: the Cameras object to render. It can be batched. + trace_name: name to label the trace with. + subplot_idx: identifies the subplot, with 0 being the top left. + ncols: the number of subplots per row. + max_rays: maximum number of plotted rays in total. Randomly subsamples + without replacement in case the number of rays is bigger than max_rays. + max_points_per_ray: maximum number of points plotted per ray. + marker_size: the size of the ray point markers. + line_width: the width of the ray lines. + """ + + n_pts_per_ray = ray_bundle.lengths.shape[-1] + n_rays = ray_bundle.lengths.shape[:-1].numel() # pyre-ignore[16] + + # flatten all batches of rays into a single big bundle + ray_bundle_flat = RayBundle( + **{ + attr: torch.flatten(getattr(ray_bundle, attr), start_dim=0, end_dim=-2) + for attr in ["origins", "directions", "lengths", "xys"] + } + ) + + # subsample the rays (if needed) + if n_rays > max_rays: + indices_rays = torch.randperm(n_rays)[:max_rays] + ray_bundle_flat = RayBundle( + **{ + attr: getattr(ray_bundle_flat, attr)[indices_rays] + for attr in ["origins", "directions", "lengths", "xys"] + } + ) + + # make ray line endpoints + min_max_ray_depth = torch.stack( + [ + ray_bundle_flat.lengths.min(dim=1).values, # pyre-ignore[16] + ray_bundle_flat.lengths.max(dim=1).values, + ], + dim=-1, + ) + ray_lines_endpoints = ray_bundle_to_ray_points( + ray_bundle_flat._replace(lengths=min_max_ray_depth) + ) + + # make the ray lines for plotly plotting + nan_tensor = torch.Tensor([[float("NaN")] * 3]) + ray_lines = torch.empty(size=(1, 3)) + for ray_line in ray_lines_endpoints: + # We combine the ray lines into a single tensor to plot them in a + # single trace. The NaNs are inserted between sets of ray lines + # so that the lines drawn by Plotly are not drawn between + # lines that belong to different rays. + ray_lines = torch.cat((ray_lines, nan_tensor, ray_line)) + x, y, z = ray_lines.detach().cpu().numpy().T.astype(float) + row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1 + fig.add_trace( + go.Scatter3d( + x=x, + y=y, + z=z, + marker={"size": 0.1}, + line={"width": line_width}, + name=trace_name, + ), + row=row, + col=col, + ) + + # subsample the ray points (if needed) + if n_pts_per_ray > max_points_per_ray: + indices_ray_pts = torch.cat( + [ + torch.randperm(n_pts_per_ray)[:max_points_per_ray] + ri * n_pts_per_ray + for ri in range(ray_bundle_flat.lengths.shape[0]) + ] + ) + ray_bundle_flat = ray_bundle_flat._replace( + lengths=ray_bundle_flat.lengths.reshape(-1)[indices_ray_pts].reshape( + ray_bundle_flat.lengths.shape[0], -1 + ) + ) + + # plot the ray points + ray_points = ( + ray_bundle_to_ray_points(ray_bundle_flat) + .view(-1, 3) + .detach() + .cpu() + .numpy() + .astype(float) + ) + fig.add_trace( + go.Scatter3d( + x=ray_points[:, 0], + y=ray_points[:, 1], + z=ray_points[:, 2], + mode="markers", + name=trace_name + "_points", + marker={"size": marker_size}, + ), + row=row, + col=col, + ) + + # Access the current subplot's scene configuration + plot_scene = "scene" + str(subplot_idx + 1) + current_layout = fig["layout"][plot_scene] + + # update the bounds of the axes for the current trace + all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3) + ray_points_center = all_ray_points.mean(dim=0) + max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item() + _update_axes_bounds(ray_points_center, float(max_expand), current_layout) + + def _gen_fig_with_subplots( batch_size: int, ncols: int, subplot_titles: List[str] ): # pragma: no cover