mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Fixes for RayBundle plotting
Summary: Fixes some issues with RayBundle plotting: - allows plotting raybundles on gpu - view -> reshape since we do not require contiguous raybundle tensors as input Reviewed By: bottler, shapovalov Differential Revision: D42665923 fbshipit-source-id: e9c6c7810428365dca4cb5ec80ef15ff28644163
This commit is contained in:
		
							parent
							
								
									a12612a48f
								
							
						
					
					
						commit
						9dc28f5dd5
					
				@ -3,3 +3,19 @@
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from .plotly_vis import get_camera_wireframe, plot_batch_individually, plot_scene
 | 
			
		||||
except ModuleNotFoundError as err:
 | 
			
		||||
    if "plotly" in str(err):
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            "Cannot import plotly-based visualization code."
 | 
			
		||||
            " Please install plotly to enable (pip install plotly)."
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        raise
 | 
			
		||||
 | 
			
		||||
from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL
 | 
			
		||||
 | 
			
		||||
@ -100,6 +100,7 @@ class Lighting(NamedTuple):  # pragma: no cover
 | 
			
		||||
    vertexnormalsepsilon: float = 1e-12
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def plot_scene(
 | 
			
		||||
    plots: Dict[str, Dict[str, Struct]],
 | 
			
		||||
    *,
 | 
			
		||||
@ -407,6 +408,7 @@ def plot_scene(
 | 
			
		||||
    return fig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def plot_batch_individually(
 | 
			
		||||
    batched_structs: Union[
 | 
			
		||||
        List[Struct],
 | 
			
		||||
@ -888,8 +890,12 @@ def _add_ray_bundle_trace(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # make the ray lines for plotly plotting
 | 
			
		||||
    nan_tensor = torch.Tensor([[float("NaN")] * 3])
 | 
			
		||||
    ray_lines = torch.empty(size=(1, 3))
 | 
			
		||||
    nan_tensor = torch.tensor(
 | 
			
		||||
        [[float("NaN")] * 3],
 | 
			
		||||
        device=ray_lines_endpoints.device,
 | 
			
		||||
        dtype=ray_lines_endpoints.dtype,
 | 
			
		||||
    )
 | 
			
		||||
    ray_lines = torch.empty(size=(1, 3), device=ray_lines_endpoints.device)
 | 
			
		||||
    for ray_line in ray_lines_endpoints:
 | 
			
		||||
        # We combine the ray lines into a single tensor to plot them in a
 | 
			
		||||
        # single trace. The NaNs are inserted between sets of ray lines
 | 
			
		||||
@ -952,7 +958,7 @@ def _add_ray_bundle_trace(
 | 
			
		||||
    current_layout = fig["layout"][plot_scene]
 | 
			
		||||
 | 
			
		||||
    # update the bounds of the axes for the current trace
 | 
			
		||||
    all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
 | 
			
		||||
    all_ray_points = ray_bundle_to_ray_points(ray_bundle).reshape(-1, 3)
 | 
			
		||||
    ray_points_center = all_ray_points.mean(dim=0)
 | 
			
		||||
    max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
 | 
			
		||||
    _update_axes_bounds(ray_points_center, float(max_expand), current_layout)
 | 
			
		||||
@ -1002,6 +1008,7 @@ def _update_axes_bounds(
 | 
			
		||||
        max_expand: the maximum spread in any dimension of the trace's vertices.
 | 
			
		||||
        current_layout: the plotly figure layout scene corresponding to the referenced trace.
 | 
			
		||||
    """
 | 
			
		||||
    verts_center = verts_center.detach().cpu()
 | 
			
		||||
    verts_min = verts_center - max_expand
 | 
			
		||||
    verts_max = verts_center + max_expand
 | 
			
		||||
    bounds = torch.t(torch.stack((verts_min, verts_max)))
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										74
									
								
								tests/test_vis.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								tests/test_vis.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,74 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.renderer import HeterogeneousRayBundle, PerspectiveCameras, RayBundle
 | 
			
		||||
from pytorch3d.structures import Meshes, Pointclouds
 | 
			
		||||
from pytorch3d.transforms import random_rotations
 | 
			
		||||
 | 
			
		||||
# Some of these imports are only needed for testing code coverage
 | 
			
		||||
from pytorch3d.vis import (  # noqa: F401
 | 
			
		||||
    get_camera_wireframe,  # noqa: F401
 | 
			
		||||
    plot_batch_individually,  # noqa: F401
 | 
			
		||||
    plot_scene,
 | 
			
		||||
    texturesuv_image_PIL,  # noqa: F401
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPlotlyVis(unittest.TestCase):
 | 
			
		||||
    def test_plot_scene(
 | 
			
		||||
        self,
 | 
			
		||||
        B: int = 3,
 | 
			
		||||
        n_rays: int = 128,
 | 
			
		||||
        n_pts_per_ray: int = 32,
 | 
			
		||||
        n_verts: int = 32,
 | 
			
		||||
        n_edges: int = 64,
 | 
			
		||||
        n_pts: int = 256,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Tests plotting of all supported structures using plot_scene.
 | 
			
		||||
        """
 | 
			
		||||
        for device in ["cpu", "cuda:0"]:
 | 
			
		||||
            plot_scene(
 | 
			
		||||
                {
 | 
			
		||||
                    "scene": {
 | 
			
		||||
                        "ray_bundle": RayBundle(
 | 
			
		||||
                            origins=torch.randn(B, n_rays, 3, device=device),
 | 
			
		||||
                            xys=torch.randn(B, n_rays, 2, device=device),
 | 
			
		||||
                            directions=torch.randn(B, n_rays, 3, device=device),
 | 
			
		||||
                            lengths=torch.randn(
 | 
			
		||||
                                B, n_rays, n_pts_per_ray, device=device
 | 
			
		||||
                            ),
 | 
			
		||||
                        ),
 | 
			
		||||
                        "heterogeneous_ray_bundle": HeterogeneousRayBundle(
 | 
			
		||||
                            origins=torch.randn(B * n_rays, 3, device=device),
 | 
			
		||||
                            xys=torch.randn(B * n_rays, 2, device=device),
 | 
			
		||||
                            directions=torch.randn(B * n_rays, 3, device=device),
 | 
			
		||||
                            lengths=torch.randn(
 | 
			
		||||
                                B * n_rays, n_pts_per_ray, device=device
 | 
			
		||||
                            ),
 | 
			
		||||
                            camera_ids=torch.randint(
 | 
			
		||||
                                low=0, high=B, size=(B * n_rays,), device=device
 | 
			
		||||
                            ),
 | 
			
		||||
                        ),
 | 
			
		||||
                        "camera": PerspectiveCameras(
 | 
			
		||||
                            R=random_rotations(B, device=device),
 | 
			
		||||
                            T=torch.randn(B, 3, device=device),
 | 
			
		||||
                        ),
 | 
			
		||||
                        "mesh": Meshes(
 | 
			
		||||
                            verts=torch.randn(B, n_verts, 3, device=device),
 | 
			
		||||
                            faces=torch.randint(
 | 
			
		||||
                                low=0, high=n_verts, size=(B, n_edges, 3), device=device
 | 
			
		||||
                            ),
 | 
			
		||||
                        ),
 | 
			
		||||
                        "point_clouds": Pointclouds(
 | 
			
		||||
                            points=torch.randn(B, n_pts, 3, device=device),
 | 
			
		||||
                        ),
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user