From bf7aca320a2a7d7cd38438f1e07dc2c6fda91cf1 Mon Sep 17 00:00:00 2001 From: Amitav Baruah Date: Tue, 20 Oct 2020 17:14:38 -0700 Subject: [PATCH] Add wrapper function to plot batches Summary: - adds plot_batch_individually - for each batched object, plots each object in its own subplot with other same-indexed elements of the other batched objects provided as input Reviewed By: nikhilaravi Differential Revision: D24258389 fbshipit-source-id: a80128e6e7a03a44c257b0598569159afadb2d39 --- pytorch3d/vis/__init__.py | 2 +- pytorch3d/vis/plotly_vis.py | 141 +++++++++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/pytorch3d/vis/__init__.py b/pytorch3d/vis/__init__.py index 2ce6f7bb..3dfbf532 100644 --- a/pytorch3d/vis/__init__.py +++ b/pytorch3d/vis/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .plotly_vis import AxisArgs, Lighting, plot_scene +from .plotly_vis import AxisArgs, Lighting, plot_batch_individually, plot_scene __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index fbe8e1b2..f417f4b3 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import warnings -from typing import Dict, List, NamedTuple, Union +from typing import Dict, List, NamedTuple, Optional, Union import numpy as np import plotly.graph_objects as go @@ -202,6 +202,145 @@ def plot_scene( return fig +def plot_batch_individually( + batched_structs: Union[List[Union[Meshes, Pointclouds]], Meshes, Pointclouds], + *, + ncols: int = 1, + extend_struct: bool = True, + subplot_titles: Optional[List[str]] = None, + **kwargs, +): + """ + This is a higher level plotting function than plot_scene, for plotting + Meshes and Pointclouds in simple cases. The simplest use is to plot a + single Meshes or Pointclouds object, where you just pass it in as a + one element list. This will plot each batch element in a separate subplot. + + More generally, you can supply multiple Meshes or Pointclouds + having the same batch size `n`. In this case, there will be `n` subplots, + each depicting the corresponding batch element of all the inputs. + + In addition, you can include Meshes and Pointclouds of size 1 in + the input. These will either be rendered in the first subplot + (if extend_struct is False), or in every subplot. + + Args: + batched_structs: a list of Meshes and/or Pointclouds to be rendered. + Each structure's corresponding batch element will be plotted in + a single subplot, resulting in n subplots for a batch of size n. + Every struct should either have the same batch size or be of batch size 1. + See extend_struct and the description above for how batch size 1 structs + are handled. Also accepts a single Meshes or Pointclouds object, which will have + each individual element plotted in its own subplot. + ncols: the number of subplots per row + extend_struct: if True, indicates that structs of batch size 1 + should be plotted in every subplot. + subplot_titles: strings to name each subplot + **kwargs: keyword arguments which are passed to plot_scene. + See plot_scene documentation for details. + + Example: + + ..code-block::python + + mesh = ... # mesh of batch size 2 + point_cloud = ... # point_cloud of batch size 2 + fig = plot_batch_individually([mesh, point_cloud], subplot_titles=["plot1", "plot2"]) + fig.show() + + # this is equivalent to the below figure + fig = plot_scene({ + "plot1": { + "trace1-1": mesh[0], + "trace1-2": point_cloud[0] + }, + "plot2":{ + "trace2-1": mesh[1], + "trace2-2": point_cloud[1] + } + }) + fig.show() + + The above example will render two subplots which each have both a mesh and pointcloud. + For more examples look at the pytorch3d tutorials at `pytorch3d/docs/tutorials`, + in particular the files rendered_color_points.ipynb and rendered_textured_meshes.ipynb. + """ + + # check that every batch is the same size or is size 1 + if len(batched_structs) == 0: + msg = "No structs to plot" + warnings.warn(msg) + return + max_size = 0 + if isinstance(batched_structs, list): + max_size = max(len(s) for s in batched_structs) + for struct in batched_structs: + if len(struct) not in (1, max_size): + msg = "invalid batch size {} provided: {}".format(len(struct), struct) + raise ValueError(msg) + else: + max_size = len(batched_structs) + + if max_size == 0: + msg = "No data is provided with at least one element" + raise ValueError(msg) + + if subplot_titles: + if len(subplot_titles) != max_size: + msg = "invalid number of subplot titles" + raise ValueError(msg) + + scene_dictionary = {} + # construct the scene dictionary + for scene_num in range(max_size): + subplot_title = ( + subplot_titles[scene_num] + if subplot_titles + else "subplot " + str(scene_num + 1) + ) + scene_dictionary[subplot_title] = {} + + if isinstance(batched_structs, list): + for i, batched_struct in enumerate(batched_structs): + # check for whether this struct needs to be extended + if i >= len(batched_struct) and not extend_struct: + continue + _add_struct_from_batch( + batched_struct, scene_num, subplot_title, scene_dictionary, i + 1 + ) + else: # batched_structs is a single struct + _add_struct_from_batch( + batched_structs, scene_num, subplot_title, scene_dictionary + ) + + return plot_scene(scene_dictionary, ncols=ncols, **kwargs) + + +def _add_struct_from_batch( + batched_struct: Union[Meshes, Pointclouds], + scene_num: int, + subplot_title: str, + scene_dictionary: Dict[str, Dict[str, Union[Meshes, Pointclouds]]], + trace_idx: int = 1, +): + """ + Adds the struct corresponding to the given scene_num index to + a provided scene_dictionary to be passed in to plot_scene + + Args: + batched_struct: the batched data structure to add to the dict + scene_num: the subplot from plot_batch_individually which this struct + should be added to + subplot_title: the title of the subplot + scene_dictionary: the dictionary to add the indexed struct to + trace_idx: the trace number, starting at 1 for this struct's trace + """ + struct_idx = min(scene_num, len(batched_struct) - 1) + struct = batched_struct[struct_idx] + trace_name = "trace{}-{}".format(scene_num + 1, trace_idx) + scene_dictionary[subplot_title][trace_name] = struct + + def _add_mesh_trace( fig: go.Figure, # pyre-ignore[11] meshes: Meshes,