mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Fixes for RayBundle plotting
Summary: Fixes some issues with RayBundle plotting: - allows plotting raybundles on gpu - view -> reshape since we do not require contiguous raybundle tensors as input Reviewed By: bottler, shapovalov Differential Revision: D42665923 fbshipit-source-id: e9c6c7810428365dca4cb5ec80ef15ff28644163
This commit is contained in:
parent
a12612a48f
commit
9dc28f5dd5
@ -3,3 +3,19 @@
|
|||||||
#
|
#
|
||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .plotly_vis import get_camera_wireframe, plot_batch_individually, plot_scene
|
||||||
|
except ModuleNotFoundError as err:
|
||||||
|
if "plotly" in str(err):
|
||||||
|
warnings.warn(
|
||||||
|
"Cannot import plotly-based visualization code."
|
||||||
|
" Please install plotly to enable (pip install plotly)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL
|
||||||
|
@ -100,6 +100,7 @@ class Lighting(NamedTuple): # pragma: no cover
|
|||||||
vertexnormalsepsilon: float = 1e-12
|
vertexnormalsepsilon: float = 1e-12
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def plot_scene(
|
def plot_scene(
|
||||||
plots: Dict[str, Dict[str, Struct]],
|
plots: Dict[str, Dict[str, Struct]],
|
||||||
*,
|
*,
|
||||||
@ -407,6 +408,7 @@ def plot_scene(
|
|||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def plot_batch_individually(
|
def plot_batch_individually(
|
||||||
batched_structs: Union[
|
batched_structs: Union[
|
||||||
List[Struct],
|
List[Struct],
|
||||||
@ -888,8 +890,12 @@ def _add_ray_bundle_trace(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# make the ray lines for plotly plotting
|
# make the ray lines for plotly plotting
|
||||||
nan_tensor = torch.Tensor([[float("NaN")] * 3])
|
nan_tensor = torch.tensor(
|
||||||
ray_lines = torch.empty(size=(1, 3))
|
[[float("NaN")] * 3],
|
||||||
|
device=ray_lines_endpoints.device,
|
||||||
|
dtype=ray_lines_endpoints.dtype,
|
||||||
|
)
|
||||||
|
ray_lines = torch.empty(size=(1, 3), device=ray_lines_endpoints.device)
|
||||||
for ray_line in ray_lines_endpoints:
|
for ray_line in ray_lines_endpoints:
|
||||||
# We combine the ray lines into a single tensor to plot them in a
|
# We combine the ray lines into a single tensor to plot them in a
|
||||||
# single trace. The NaNs are inserted between sets of ray lines
|
# single trace. The NaNs are inserted between sets of ray lines
|
||||||
@ -952,7 +958,7 @@ def _add_ray_bundle_trace(
|
|||||||
current_layout = fig["layout"][plot_scene]
|
current_layout = fig["layout"][plot_scene]
|
||||||
|
|
||||||
# update the bounds of the axes for the current trace
|
# update the bounds of the axes for the current trace
|
||||||
all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
|
all_ray_points = ray_bundle_to_ray_points(ray_bundle).reshape(-1, 3)
|
||||||
ray_points_center = all_ray_points.mean(dim=0)
|
ray_points_center = all_ray_points.mean(dim=0)
|
||||||
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
|
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
|
||||||
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)
|
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)
|
||||||
@ -1002,6 +1008,7 @@ def _update_axes_bounds(
|
|||||||
max_expand: the maximum spread in any dimension of the trace's vertices.
|
max_expand: the maximum spread in any dimension of the trace's vertices.
|
||||||
current_layout: the plotly figure layout scene corresponding to the referenced trace.
|
current_layout: the plotly figure layout scene corresponding to the referenced trace.
|
||||||
"""
|
"""
|
||||||
|
verts_center = verts_center.detach().cpu()
|
||||||
verts_min = verts_center - max_expand
|
verts_min = verts_center - max_expand
|
||||||
verts_max = verts_center + max_expand
|
verts_max = verts_center + max_expand
|
||||||
bounds = torch.t(torch.stack((verts_min, verts_max)))
|
bounds = torch.t(torch.stack((verts_min, verts_max)))
|
||||||
|
74
tests/test_vis.py
Normal file
74
tests/test_vis.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.renderer import HeterogeneousRayBundle, PerspectiveCameras, RayBundle
|
||||||
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
|
from pytorch3d.transforms import random_rotations
|
||||||
|
|
||||||
|
# Some of these imports are only needed for testing code coverage
|
||||||
|
from pytorch3d.vis import ( # noqa: F401
|
||||||
|
get_camera_wireframe, # noqa: F401
|
||||||
|
plot_batch_individually, # noqa: F401
|
||||||
|
plot_scene,
|
||||||
|
texturesuv_image_PIL, # noqa: F401
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlotlyVis(unittest.TestCase):
|
||||||
|
def test_plot_scene(
|
||||||
|
self,
|
||||||
|
B: int = 3,
|
||||||
|
n_rays: int = 128,
|
||||||
|
n_pts_per_ray: int = 32,
|
||||||
|
n_verts: int = 32,
|
||||||
|
n_edges: int = 64,
|
||||||
|
n_pts: int = 256,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Tests plotting of all supported structures using plot_scene.
|
||||||
|
"""
|
||||||
|
for device in ["cpu", "cuda:0"]:
|
||||||
|
plot_scene(
|
||||||
|
{
|
||||||
|
"scene": {
|
||||||
|
"ray_bundle": RayBundle(
|
||||||
|
origins=torch.randn(B, n_rays, 3, device=device),
|
||||||
|
xys=torch.randn(B, n_rays, 2, device=device),
|
||||||
|
directions=torch.randn(B, n_rays, 3, device=device),
|
||||||
|
lengths=torch.randn(
|
||||||
|
B, n_rays, n_pts_per_ray, device=device
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"heterogeneous_ray_bundle": HeterogeneousRayBundle(
|
||||||
|
origins=torch.randn(B * n_rays, 3, device=device),
|
||||||
|
xys=torch.randn(B * n_rays, 2, device=device),
|
||||||
|
directions=torch.randn(B * n_rays, 3, device=device),
|
||||||
|
lengths=torch.randn(
|
||||||
|
B * n_rays, n_pts_per_ray, device=device
|
||||||
|
),
|
||||||
|
camera_ids=torch.randint(
|
||||||
|
low=0, high=B, size=(B * n_rays,), device=device
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"camera": PerspectiveCameras(
|
||||||
|
R=random_rotations(B, device=device),
|
||||||
|
T=torch.randn(B, 3, device=device),
|
||||||
|
),
|
||||||
|
"mesh": Meshes(
|
||||||
|
verts=torch.randn(B, n_verts, 3, device=device),
|
||||||
|
faces=torch.randint(
|
||||||
|
low=0, high=n_verts, size=(B, n_edges, 3), device=device
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"point_clouds": Pointclouds(
|
||||||
|
points=torch.randn(B, n_pts, 3, device=device),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user