mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Fixes for RayBundle plotting
Summary: Fixes some issues with RayBundle plotting: - allows plotting raybundles on gpu - view -> reshape since we do not require contiguous raybundle tensors as input Reviewed By: bottler, shapovalov Differential Revision: D42665923 fbshipit-source-id: e9c6c7810428365dca4cb5ec80ef15ff28644163
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a12612a48f
commit
9dc28f5dd5
74
tests/test_vis.py
Normal file
74
tests/test_vis.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import HeterogeneousRayBundle, PerspectiveCameras, RayBundle
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
from pytorch3d.transforms import random_rotations
|
||||
|
||||
# Some of these imports are only needed for testing code coverage
|
||||
from pytorch3d.vis import ( # noqa: F401
|
||||
get_camera_wireframe, # noqa: F401
|
||||
plot_batch_individually, # noqa: F401
|
||||
plot_scene,
|
||||
texturesuv_image_PIL, # noqa: F401
|
||||
)
|
||||
|
||||
|
||||
class TestPlotlyVis(unittest.TestCase):
|
||||
def test_plot_scene(
|
||||
self,
|
||||
B: int = 3,
|
||||
n_rays: int = 128,
|
||||
n_pts_per_ray: int = 32,
|
||||
n_verts: int = 32,
|
||||
n_edges: int = 64,
|
||||
n_pts: int = 256,
|
||||
):
|
||||
"""
|
||||
Tests plotting of all supported structures using plot_scene.
|
||||
"""
|
||||
for device in ["cpu", "cuda:0"]:
|
||||
plot_scene(
|
||||
{
|
||||
"scene": {
|
||||
"ray_bundle": RayBundle(
|
||||
origins=torch.randn(B, n_rays, 3, device=device),
|
||||
xys=torch.randn(B, n_rays, 2, device=device),
|
||||
directions=torch.randn(B, n_rays, 3, device=device),
|
||||
lengths=torch.randn(
|
||||
B, n_rays, n_pts_per_ray, device=device
|
||||
),
|
||||
),
|
||||
"heterogeneous_ray_bundle": HeterogeneousRayBundle(
|
||||
origins=torch.randn(B * n_rays, 3, device=device),
|
||||
xys=torch.randn(B * n_rays, 2, device=device),
|
||||
directions=torch.randn(B * n_rays, 3, device=device),
|
||||
lengths=torch.randn(
|
||||
B * n_rays, n_pts_per_ray, device=device
|
||||
),
|
||||
camera_ids=torch.randint(
|
||||
low=0, high=B, size=(B * n_rays,), device=device
|
||||
),
|
||||
),
|
||||
"camera": PerspectiveCameras(
|
||||
R=random_rotations(B, device=device),
|
||||
T=torch.randn(B, 3, device=device),
|
||||
),
|
||||
"mesh": Meshes(
|
||||
verts=torch.randn(B, n_verts, 3, device=device),
|
||||
faces=torch.randint(
|
||||
low=0, high=n_verts, size=(B, n_edges, 3), device=device
|
||||
),
|
||||
),
|
||||
"point_clouds": Pointclouds(
|
||||
points=torch.randn(B, n_pts, 3, device=device),
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user