mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Adding join_mesh in pytorch3d.structures.meshes
Summary: Adding a function in pytorch3d.structures.meshes to join multiple meshes into a Meshes object representing a single mesh. The function currently ignores all textures. Reviewed By: nikhilaravi Differential Revision: D21876908 fbshipit-source-id: 448602857e9d3d3f774d18bb4e93076f78329823
This commit is contained in:
		
							parent
							
								
									4b78e95eeb
								
							
						
					
					
						commit
						e053d7c456
					
				@ -1,6 +1,6 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from typing import List
 | 
			
		||||
from typing import List, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
@ -1539,3 +1539,28 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
 | 
			
		||||
 | 
			
		||||
    tex = Textures(**kwargs)
 | 
			
		||||
    return Meshes(verts=verts, faces=faces, textures=tex)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
 | 
			
		||||
    """
 | 
			
		||||
    Joins a batch of meshes in the form of a Meshes object or a list of Meshes
 | 
			
		||||
    objects as a single mesh. If the input is a list, the Meshes objects in the list
 | 
			
		||||
    must all be on the same device. This version ignores all textures in the input mehses.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        new Meshes object containing a single mesh
 | 
			
		||||
    """
 | 
			
		||||
    if isinstance(meshes, List):
 | 
			
		||||
        meshes = join_meshes_as_batch(meshes, include_textures=False)
 | 
			
		||||
 | 
			
		||||
    if len(meshes) == 1:
 | 
			
		||||
        return meshes
 | 
			
		||||
    verts = meshes.verts_packed()  # (sum(V_n), 3)
 | 
			
		||||
    # Offset automatically done by faces_packed
 | 
			
		||||
    faces = meshes.faces_packed()  # (sum(F_n), 3)
 | 
			
		||||
 | 
			
		||||
    mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0))
 | 
			
		||||
    return mesh
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joined_spheres_flat.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_joined_spheres_flat.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 26 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joined_spheres_gouraud.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_joined_spheres_gouraud.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 21 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joined_spheres_phong.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_joined_spheres_phong.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 21 KiB  | 
@ -14,6 +14,7 @@ from pytorch3d.io.mtl_io import (
 | 
			
		||||
    _bilinear_interpolation_vectorized,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
 | 
			
		||||
from pytorch3d.structures.meshes import join_mesh
 | 
			
		||||
from pytorch3d.utils import torus
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -648,6 +649,42 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
 | 
			
		||||
        self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
 | 
			
		||||
 | 
			
		||||
    def test_join_meshes(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test that join_mesh joins single meshes and the corresponding values are
 | 
			
		||||
        consistent with the single meshes.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Load cow mesh.
 | 
			
		||||
        DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
 | 
			
		||||
        cow_obj = DATA_DIR / "cow_mesh/cow.obj"
 | 
			
		||||
 | 
			
		||||
        cow_mesh = load_objs_as_meshes([cow_obj])
 | 
			
		||||
        cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
 | 
			
		||||
        # Join a batch of three single meshes and check that the values are consistent
 | 
			
		||||
        # with the individual meshes.
 | 
			
		||||
        cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])
 | 
			
		||||
 | 
			
		||||
        def check_item(x, y, offset):
 | 
			
		||||
            self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)
 | 
			
		||||
 | 
			
		||||
        check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
 | 
			
		||||
        check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)
 | 
			
		||||
 | 
			
		||||
        # Test the joining of meshes of different sizes.
 | 
			
		||||
        teapot_obj = DATA_DIR / "teapot.obj"
 | 
			
		||||
        teapot_mesh = load_objs_as_meshes([teapot_obj])
 | 
			
		||||
        teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)
 | 
			
		||||
 | 
			
		||||
        mix_mesh = join_mesh([cow_mesh, teapot_mesh])
 | 
			
		||||
        mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
 | 
			
		||||
        self.assertEqual(len(mix_mesh), 1)
 | 
			
		||||
 | 
			
		||||
        self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
 | 
			
		||||
        self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
 | 
			
		||||
        self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
 | 
			
		||||
        self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
 | 
			
		||||
        return lambda: save_obj(StringIO(), verts, faces, decimal_places)
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ from pytorch3d.renderer.mesh.shader import (
 | 
			
		||||
    TexturedSoftPhongShader,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.mesh.texturing import Textures
 | 
			
		||||
from pytorch3d.structures.meshes import Meshes
 | 
			
		||||
from pytorch3d.structures.meshes import Meshes, join_mesh
 | 
			
		||||
from pytorch3d.utils.ico_sphere import ico_sphere
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -176,7 +176,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        # Init renderer
 | 
			
		||||
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
 | 
			
		||||
        shaders = {
 | 
			
		||||
            "phong": HardGouraudShader,
 | 
			
		||||
            "phong": HardPhongShader,
 | 
			
		||||
            "gouraud": HardGouraudShader,
 | 
			
		||||
            "flat": HardFlatShader,
 | 
			
		||||
        }
 | 
			
		||||
@ -369,3 +369,70 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            self.assertClose(rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
    def test_joined_spheres(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a list of Meshes can be joined as a single mesh and
 | 
			
		||||
        the single mesh is rendered correctly with Phong, Gouraud
 | 
			
		||||
        and Flat Shaders.
 | 
			
		||||
        """
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
 | 
			
		||||
        # Init mesh with vertex textures.
 | 
			
		||||
        # Initialize a list containing two ico spheres of different sizes.
 | 
			
		||||
        sphere_list = [ico_sphere(3, device), ico_sphere(4, device)]
 | 
			
		||||
        # [(42 verts, 80 faces), (162 verts, 320 faces)]
 | 
			
		||||
        # The scale the vertices need to be set at to resize the spheres
 | 
			
		||||
        scales = [0.25, 1]
 | 
			
		||||
        # The distance the spheres ought to be offset horizontally to prevent overlap.
 | 
			
		||||
        offsets = [1.2, -0.3]
 | 
			
		||||
        # Initialize a list containing the adjusted sphere meshes.
 | 
			
		||||
        sphere_mesh_list = []
 | 
			
		||||
        for i in range(len(sphere_list)):
 | 
			
		||||
            verts = sphere_list[i].verts_padded() * scales[i]
 | 
			
		||||
            verts[0, :, 0] += offsets[i]
 | 
			
		||||
            sphere_mesh_list.append(
 | 
			
		||||
                Meshes(verts=verts, faces=sphere_list[i].faces_padded())
 | 
			
		||||
            )
 | 
			
		||||
        joined_sphere_mesh = join_mesh(sphere_mesh_list)
 | 
			
		||||
        joined_sphere_mesh.textures = Textures(
 | 
			
		||||
            verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Init rasterizer settings
 | 
			
		||||
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
 | 
			
		||||
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
        raster_settings = RasterizationSettings(
 | 
			
		||||
            image_size=512, blur_radius=0.0, faces_per_pixel=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Init shader settings
 | 
			
		||||
        materials = Materials(device=device)
 | 
			
		||||
        lights = PointLights(device=device)
 | 
			
		||||
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
 | 
			
		||||
        blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 | 
			
		||||
 | 
			
		||||
        # Init renderer
 | 
			
		||||
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
 | 
			
		||||
        shaders = {
 | 
			
		||||
            "phong": HardPhongShader,
 | 
			
		||||
            "gouraud": HardGouraudShader,
 | 
			
		||||
            "flat": HardFlatShader,
 | 
			
		||||
        }
 | 
			
		||||
        for (name, shader_init) in shaders.items():
 | 
			
		||||
            shader = shader_init(
 | 
			
		||||
                lights=lights,
 | 
			
		||||
                cameras=cameras,
 | 
			
		||||
                materials=materials,
 | 
			
		||||
                blend_params=blend_params,
 | 
			
		||||
            )
 | 
			
		||||
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
 | 
			
		||||
            image = renderer(joined_sphere_mesh)
 | 
			
		||||
            rgb = image[..., :3].squeeze().cpu()
 | 
			
		||||
            if DEBUG:
 | 
			
		||||
                file_name = "DEBUG_joined_spheres_%s.png" % name
 | 
			
		||||
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / file_name
 | 
			
		||||
                )
 | 
			
		||||
            image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
 | 
			
		||||
            self.assertClose(rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user