mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Culling to frustrum bug fix
Summary: When `z_clip_value = None` and faces are outside the view frustum the shape of one of the tensors in `clip.py` is incorrect. `faces_num_clipped_verts` should be (F,) but it was (F,3). Added a new test to ensure this case is handled. Reviewed By: bottler Differential Revision: D29051282 fbshipit-source-id: 5f4172ba4d4a75d928404dde9abf48aef18c68bd
This commit is contained in:
parent
ef16253953
commit
a0f79318c5
@ -372,7 +372,7 @@ def clip_faces(
|
||||
# (F) dim tensor containing the number of clipped vertices in each triangle
|
||||
faces_num_clipped_verts = faces_clipped_verts.sum(1)
|
||||
else:
|
||||
faces_num_clipped_verts = torch.zeros([F, 3], device=device)
|
||||
faces_num_clipped_verts = torch.zeros([F], device=device)
|
||||
|
||||
# If no triangles need to be clipped or culled, avoid unnecessary computation
|
||||
# and return early
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 7.2 KiB |
@ -15,7 +15,11 @@ import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_tests_dir, load_rgb_image
|
||||
from pytorch3d.io import save_obj
|
||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.renderer.cameras import (
|
||||
FoVPerspectiveCameras,
|
||||
look_at_view_transform,
|
||||
PerspectiveCameras,
|
||||
)
|
||||
from pytorch3d.renderer.lighting import PointLights
|
||||
from pytorch3d.renderer.mesh import (
|
||||
ClipFrustum,
|
||||
@ -27,8 +31,9 @@ from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
|
||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||
from pytorch3d.renderer.mesh.shader import SoftPhongShader
|
||||
from pytorch3d.renderer.mesh.textures import TexturesVertex
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
from pytorch3d.utils import torus
|
||||
|
||||
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||
# All saved images have prefix DEBUG_
|
||||
@ -97,9 +102,9 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
||||
return mesh, verts
|
||||
return mesh
|
||||
|
||||
def test_cube_mesh_render(self):
|
||||
def debug_cube_mesh_render(self):
|
||||
"""
|
||||
End-End test of rendering a cube mesh with texture
|
||||
End-End debug run of rendering a cube mesh with texture
|
||||
from decreasing camera distances. The camera starts
|
||||
outside the cube and enters the inside of the cube.
|
||||
"""
|
||||
@ -132,22 +137,16 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
||||
# the camera enters the cube. Check the output looks correct.
|
||||
images_list = []
|
||||
dists = np.linspace(0.1, 2.5, 20)[::-1]
|
||||
|
||||
for d in dists:
|
||||
R, T = look_at_view_transform(d, 0, 0)
|
||||
T[0, 1] -= 0.1 # move down in the y axis
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
|
||||
images = renderer(mesh, cameras=cameras)
|
||||
rgb = images[0, ..., :3].cpu().detach()
|
||||
filename = "DEBUG_cube_dist=%.1f.jpg" % d
|
||||
im = (rgb.numpy() * 255).astype(np.uint8)
|
||||
images_list.append(im)
|
||||
|
||||
# Check one of the images where the camera is inside the mesh
|
||||
if d == 0.5:
|
||||
filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
|
||||
image_ref = load_rgb_image(filename, DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
# Save a gif of the output - this should show
|
||||
# the camera moving inside the cube.
|
||||
if DEBUG:
|
||||
@ -655,3 +654,25 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
||||
double_hit = torch.tensor([0, 0, -1], device=device)
|
||||
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
|
||||
self.assertFalse(check_double_hit)
|
||||
|
||||
def test_mesh_outside_frustrum(self):
|
||||
"""
|
||||
Test the case where the mesh is completely outside the view
|
||||
frustrum so all faces are culled and z_clip_value = None.
|
||||
"""
|
||||
device = "cuda:0"
|
||||
mesh = torus(20.0, 85.0, 32, 16, device=device)
|
||||
tex = TexturesVertex(verts_features=torch.rand_like(mesh.verts_padded()))
|
||||
mesh.textures = tex
|
||||
raster_settings = RasterizationSettings(image_size=512, cull_to_frustum=True)
|
||||
R, T = look_at_view_transform(1.0, 0.0, 0.0)
|
||||
cameras = PerspectiveCameras(device=device, R=R, T=T)
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||
shader=SoftPhongShader(cameras=cameras, device=device),
|
||||
)
|
||||
images = renderer(mesh)
|
||||
|
||||
# Mesh is completely outside the view frustrum
|
||||
# The image should be white.
|
||||
self.assertClose(images[0, ..., :3], torch.ones_like(images[0, ..., :3]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user