mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Culling to frustrum bug fix
Summary: When `z_clip_value = None` and faces are outside the view frustum the shape of one of the tensors in `clip.py` is incorrect. `faces_num_clipped_verts` should be (F,) but it was (F,3). Added a new test to ensure this case is handled. Reviewed By: bottler Differential Revision: D29051282 fbshipit-source-id: 5f4172ba4d4a75d928404dde9abf48aef18c68bd
This commit is contained in:
parent
ef16253953
commit
a0f79318c5
@ -372,7 +372,7 @@ def clip_faces(
|
|||||||
# (F) dim tensor containing the number of clipped vertices in each triangle
|
# (F) dim tensor containing the number of clipped vertices in each triangle
|
||||||
faces_num_clipped_verts = faces_clipped_verts.sum(1)
|
faces_num_clipped_verts = faces_clipped_verts.sum(1)
|
||||||
else:
|
else:
|
||||||
faces_num_clipped_verts = torch.zeros([F, 3], device=device)
|
faces_num_clipped_verts = torch.zeros([F], device=device)
|
||||||
|
|
||||||
# If no triangles need to be clipped or culled, avoid unnecessary computation
|
# If no triangles need to be clipped or culled, avoid unnecessary computation
|
||||||
# and return early
|
# and return early
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 7.2 KiB |
@ -15,7 +15,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin, get_tests_dir, load_rgb_image
|
from common_testing import TestCaseMixin, get_tests_dir, load_rgb_image
|
||||||
from pytorch3d.io import save_obj
|
from pytorch3d.io import save_obj
|
||||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
from pytorch3d.renderer.cameras import (
|
||||||
|
FoVPerspectiveCameras,
|
||||||
|
look_at_view_transform,
|
||||||
|
PerspectiveCameras,
|
||||||
|
)
|
||||||
from pytorch3d.renderer.lighting import PointLights
|
from pytorch3d.renderer.lighting import PointLights
|
||||||
from pytorch3d.renderer.mesh import (
|
from pytorch3d.renderer.mesh import (
|
||||||
ClipFrustum,
|
ClipFrustum,
|
||||||
@ -27,8 +31,9 @@ from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
|
|||||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
||||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||||
from pytorch3d.renderer.mesh.shader import SoftPhongShader
|
from pytorch3d.renderer.mesh.shader import SoftPhongShader
|
||||||
|
from pytorch3d.renderer.mesh.textures import TexturesVertex
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
|
from pytorch3d.utils import torus
|
||||||
|
|
||||||
# If DEBUG=True, save out images generated in the tests for debugging.
|
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||||
# All saved images have prefix DEBUG_
|
# All saved images have prefix DEBUG_
|
||||||
@ -97,9 +102,9 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
|||||||
return mesh, verts
|
return mesh, verts
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
def test_cube_mesh_render(self):
|
def debug_cube_mesh_render(self):
|
||||||
"""
|
"""
|
||||||
End-End test of rendering a cube mesh with texture
|
End-End debug run of rendering a cube mesh with texture
|
||||||
from decreasing camera distances. The camera starts
|
from decreasing camera distances. The camera starts
|
||||||
outside the cube and enters the inside of the cube.
|
outside the cube and enters the inside of the cube.
|
||||||
"""
|
"""
|
||||||
@ -132,22 +137,16 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
|||||||
# the camera enters the cube. Check the output looks correct.
|
# the camera enters the cube. Check the output looks correct.
|
||||||
images_list = []
|
images_list = []
|
||||||
dists = np.linspace(0.1, 2.5, 20)[::-1]
|
dists = np.linspace(0.1, 2.5, 20)[::-1]
|
||||||
|
|
||||||
for d in dists:
|
for d in dists:
|
||||||
R, T = look_at_view_transform(d, 0, 0)
|
R, T = look_at_view_transform(d, 0, 0)
|
||||||
T[0, 1] -= 0.1 # move down in the y axis
|
T[0, 1] -= 0.1 # move down in the y axis
|
||||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
|
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
|
||||||
images = renderer(mesh, cameras=cameras)
|
images = renderer(mesh, cameras=cameras)
|
||||||
rgb = images[0, ..., :3].cpu().detach()
|
rgb = images[0, ..., :3].cpu().detach()
|
||||||
filename = "DEBUG_cube_dist=%.1f.jpg" % d
|
|
||||||
im = (rgb.numpy() * 255).astype(np.uint8)
|
im = (rgb.numpy() * 255).astype(np.uint8)
|
||||||
images_list.append(im)
|
images_list.append(im)
|
||||||
|
|
||||||
# Check one of the images where the camera is inside the mesh
|
|
||||||
if d == 0.5:
|
|
||||||
filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
|
|
||||||
image_ref = load_rgb_image(filename, DATA_DIR)
|
|
||||||
self.assertClose(rgb, image_ref, atol=0.05)
|
|
||||||
|
|
||||||
# Save a gif of the output - this should show
|
# Save a gif of the output - this should show
|
||||||
# the camera moving inside the cube.
|
# the camera moving inside the cube.
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
@ -655,3 +654,25 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
|||||||
double_hit = torch.tensor([0, 0, -1], device=device)
|
double_hit = torch.tensor([0, 0, -1], device=device)
|
||||||
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
|
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
|
||||||
self.assertFalse(check_double_hit)
|
self.assertFalse(check_double_hit)
|
||||||
|
|
||||||
|
def test_mesh_outside_frustrum(self):
|
||||||
|
"""
|
||||||
|
Test the case where the mesh is completely outside the view
|
||||||
|
frustrum so all faces are culled and z_clip_value = None.
|
||||||
|
"""
|
||||||
|
device = "cuda:0"
|
||||||
|
mesh = torus(20.0, 85.0, 32, 16, device=device)
|
||||||
|
tex = TexturesVertex(verts_features=torch.rand_like(mesh.verts_padded()))
|
||||||
|
mesh.textures = tex
|
||||||
|
raster_settings = RasterizationSettings(image_size=512, cull_to_frustum=True)
|
||||||
|
R, T = look_at_view_transform(1.0, 0.0, 0.0)
|
||||||
|
cameras = PerspectiveCameras(device=device, R=R, T=T)
|
||||||
|
renderer = MeshRenderer(
|
||||||
|
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||||
|
shader=SoftPhongShader(cameras=cameras, device=device),
|
||||||
|
)
|
||||||
|
images = renderer(mesh)
|
||||||
|
|
||||||
|
# Mesh is completely outside the view frustrum
|
||||||
|
# The image should be white.
|
||||||
|
self.assertClose(images[0, ..., :3], torch.ones_like(images[0, ..., :3]))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user