mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Culling to frustrum bug fix
Summary: When `z_clip_value = None` and faces are outside the view frustum the shape of one of the tensors in `clip.py` is incorrect. `faces_num_clipped_verts` should be (F,) but it was (F,3). Added a new test to ensure this case is handled. Reviewed By: bottler Differential Revision: D29051282 fbshipit-source-id: 5f4172ba4d4a75d928404dde9abf48aef18c68bd
This commit is contained in:
		
							parent
							
								
									ef16253953
								
							
						
					
					
						commit
						a0f79318c5
					
				@ -372,7 +372,7 @@ def clip_faces(
 | 
			
		||||
        # (F) dim tensor containing the number of clipped vertices in each triangle
 | 
			
		||||
        faces_num_clipped_verts = faces_clipped_verts.sum(1)
 | 
			
		||||
    else:
 | 
			
		||||
        faces_num_clipped_verts = torch.zeros([F, 3], device=device)
 | 
			
		||||
        faces_num_clipped_verts = torch.zeros([F], device=device)
 | 
			
		||||
 | 
			
		||||
    # If no triangles need to be clipped or culled, avoid unnecessary computation
 | 
			
		||||
    # and return early
 | 
			
		||||
 | 
			
		||||
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 7.2 KiB  | 
@ -15,7 +15,11 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin, get_tests_dir, load_rgb_image
 | 
			
		||||
from pytorch3d.io import save_obj
 | 
			
		||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
 | 
			
		||||
from pytorch3d.renderer.cameras import (
 | 
			
		||||
    FoVPerspectiveCameras,
 | 
			
		||||
    look_at_view_transform,
 | 
			
		||||
    PerspectiveCameras,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.lighting import PointLights
 | 
			
		||||
from pytorch3d.renderer.mesh import (
 | 
			
		||||
    ClipFrustum,
 | 
			
		||||
@ -27,8 +31,9 @@ from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
 | 
			
		||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
 | 
			
		||||
from pytorch3d.renderer.mesh.shader import SoftPhongShader
 | 
			
		||||
from pytorch3d.renderer.mesh.textures import TexturesVertex
 | 
			
		||||
from pytorch3d.structures.meshes import Meshes
 | 
			
		||||
 | 
			
		||||
from pytorch3d.utils import torus
 | 
			
		||||
 | 
			
		||||
# If DEBUG=True, save out images generated in the tests for debugging.
 | 
			
		||||
# All saved images have prefix DEBUG_
 | 
			
		||||
@ -97,9 +102,9 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            return mesh, verts
 | 
			
		||||
        return mesh
 | 
			
		||||
 | 
			
		||||
    def test_cube_mesh_render(self):
 | 
			
		||||
    def debug_cube_mesh_render(self):
 | 
			
		||||
        """
 | 
			
		||||
        End-End test of rendering a cube mesh with texture
 | 
			
		||||
        End-End debug run of rendering a cube mesh with texture
 | 
			
		||||
        from decreasing camera distances. The camera starts
 | 
			
		||||
        outside the cube and enters the inside of the cube.
 | 
			
		||||
        """
 | 
			
		||||
@ -132,22 +137,16 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        # the camera enters the cube. Check the output looks correct.
 | 
			
		||||
        images_list = []
 | 
			
		||||
        dists = np.linspace(0.1, 2.5, 20)[::-1]
 | 
			
		||||
 | 
			
		||||
        for d in dists:
 | 
			
		||||
            R, T = look_at_view_transform(d, 0, 0)
 | 
			
		||||
            T[0, 1] -= 0.1  # move down in the y axis
 | 
			
		||||
            cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
 | 
			
		||||
            images = renderer(mesh, cameras=cameras)
 | 
			
		||||
            rgb = images[0, ..., :3].cpu().detach()
 | 
			
		||||
            filename = "DEBUG_cube_dist=%.1f.jpg" % d
 | 
			
		||||
            im = (rgb.numpy() * 255).astype(np.uint8)
 | 
			
		||||
            images_list.append(im)
 | 
			
		||||
 | 
			
		||||
            # Check one of the images where the camera is inside the mesh
 | 
			
		||||
            if d == 0.5:
 | 
			
		||||
                filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
 | 
			
		||||
                image_ref = load_rgb_image(filename, DATA_DIR)
 | 
			
		||||
                self.assertClose(rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
        # Save a gif of the output - this should show
 | 
			
		||||
        # the camera moving inside the cube.
 | 
			
		||||
        if DEBUG:
 | 
			
		||||
@ -655,3 +654,25 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            double_hit = torch.tensor([0, 0, -1], device=device)
 | 
			
		||||
            check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
 | 
			
		||||
            self.assertFalse(check_double_hit)
 | 
			
		||||
 | 
			
		||||
    def test_mesh_outside_frustrum(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test the case where the mesh is completely outside the view
 | 
			
		||||
        frustrum so all faces are culled and z_clip_value = None.
 | 
			
		||||
        """
 | 
			
		||||
        device = "cuda:0"
 | 
			
		||||
        mesh = torus(20.0, 85.0, 32, 16, device=device)
 | 
			
		||||
        tex = TexturesVertex(verts_features=torch.rand_like(mesh.verts_padded()))
 | 
			
		||||
        mesh.textures = tex
 | 
			
		||||
        raster_settings = RasterizationSettings(image_size=512, cull_to_frustum=True)
 | 
			
		||||
        R, T = look_at_view_transform(1.0, 0.0, 0.0)
 | 
			
		||||
        cameras = PerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
        renderer = MeshRenderer(
 | 
			
		||||
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
 | 
			
		||||
            shader=SoftPhongShader(cameras=cameras, device=device),
 | 
			
		||||
        )
 | 
			
		||||
        images = renderer(mesh)
 | 
			
		||||
 | 
			
		||||
        # Mesh is completely outside the view frustrum
 | 
			
		||||
        # The image should be white.
 | 
			
		||||
        self.assertClose(images[0, ..., :3], torch.ones_like(images[0, ..., :3]))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user