diff --git a/pytorch3d/vis/__init__.py b/pytorch3d/vis/__init__.py index 2e41cd71..972cc5ce 100644 --- a/pytorch3d/vis/__init__.py +++ b/pytorch3d/vis/__init__.py @@ -3,3 +3,19 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +import warnings + + +try: + from .plotly_vis import get_camera_wireframe, plot_batch_individually, plot_scene +except ModuleNotFoundError as err: + if "plotly" in str(err): + warnings.warn( + "Cannot import plotly-based visualization code." + " Please install plotly to enable (pip install plotly)." + ) + else: + raise + +from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index 1cb4985d..c9b62adc 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -100,6 +100,7 @@ class Lighting(NamedTuple): # pragma: no cover vertexnormalsepsilon: float = 1e-12 +@torch.no_grad() def plot_scene( plots: Dict[str, Dict[str, Struct]], *, @@ -407,6 +408,7 @@ def plot_scene( return fig +@torch.no_grad() def plot_batch_individually( batched_structs: Union[ List[Struct], @@ -888,8 +890,12 @@ def _add_ray_bundle_trace( ) # make the ray lines for plotly plotting - nan_tensor = torch.Tensor([[float("NaN")] * 3]) - ray_lines = torch.empty(size=(1, 3)) + nan_tensor = torch.tensor( + [[float("NaN")] * 3], + device=ray_lines_endpoints.device, + dtype=ray_lines_endpoints.dtype, + ) + ray_lines = torch.empty(size=(1, 3), device=ray_lines_endpoints.device) for ray_line in ray_lines_endpoints: # We combine the ray lines into a single tensor to plot them in a # single trace. The NaNs are inserted between sets of ray lines @@ -952,7 +958,7 @@ def _add_ray_bundle_trace( current_layout = fig["layout"][plot_scene] # update the bounds of the axes for the current trace - all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3) + all_ray_points = ray_bundle_to_ray_points(ray_bundle).reshape(-1, 3) ray_points_center = all_ray_points.mean(dim=0) max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item() _update_axes_bounds(ray_points_center, float(max_expand), current_layout) @@ -1002,6 +1008,7 @@ def _update_axes_bounds( max_expand: the maximum spread in any dimension of the trace's vertices. current_layout: the plotly figure layout scene corresponding to the referenced trace. """ + verts_center = verts_center.detach().cpu() verts_min = verts_center - max_expand verts_max = verts_center + max_expand bounds = torch.t(torch.stack((verts_min, verts_max))) diff --git a/tests/test_vis.py b/tests/test_vis.py new file mode 100644 index 00000000..00a3abe4 --- /dev/null +++ b/tests/test_vis.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from pytorch3d.renderer import HeterogeneousRayBundle, PerspectiveCameras, RayBundle +from pytorch3d.structures import Meshes, Pointclouds +from pytorch3d.transforms import random_rotations + +# Some of these imports are only needed for testing code coverage +from pytorch3d.vis import ( # noqa: F401 + get_camera_wireframe, # noqa: F401 + plot_batch_individually, # noqa: F401 + plot_scene, + texturesuv_image_PIL, # noqa: F401 +) + + +class TestPlotlyVis(unittest.TestCase): + def test_plot_scene( + self, + B: int = 3, + n_rays: int = 128, + n_pts_per_ray: int = 32, + n_verts: int = 32, + n_edges: int = 64, + n_pts: int = 256, + ): + """ + Tests plotting of all supported structures using plot_scene. + """ + for device in ["cpu", "cuda:0"]: + plot_scene( + { + "scene": { + "ray_bundle": RayBundle( + origins=torch.randn(B, n_rays, 3, device=device), + xys=torch.randn(B, n_rays, 2, device=device), + directions=torch.randn(B, n_rays, 3, device=device), + lengths=torch.randn( + B, n_rays, n_pts_per_ray, device=device + ), + ), + "heterogeneous_ray_bundle": HeterogeneousRayBundle( + origins=torch.randn(B * n_rays, 3, device=device), + xys=torch.randn(B * n_rays, 2, device=device), + directions=torch.randn(B * n_rays, 3, device=device), + lengths=torch.randn( + B * n_rays, n_pts_per_ray, device=device + ), + camera_ids=torch.randint( + low=0, high=B, size=(B * n_rays,), device=device + ), + ), + "camera": PerspectiveCameras( + R=random_rotations(B, device=device), + T=torch.randn(B, 3, device=device), + ), + "mesh": Meshes( + verts=torch.randn(B, n_verts, 3, device=device), + faces=torch.randint( + low=0, high=n_verts, size=(B, n_edges, 3), device=device + ), + ), + "point_clouds": Pointclouds( + points=torch.randn(B, n_pts, 3, device=device), + ), + } + } + )