diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index a72eb5dc..32c2b692 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import List +from typing import List, Union import torch @@ -1539,3 +1539,28 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True): tex = Textures(**kwargs) return Meshes(verts=verts, faces=faces, textures=tex) + + +def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes: + """ + Joins a batch of meshes in the form of a Meshes object or a list of Meshes + objects as a single mesh. If the input is a list, the Meshes objects in the list + must all be on the same device. This version ignores all textures in the input mehses. + + Args: + meshes: Meshes object that contains a batch of meshes or a list of Meshes objects + + Returns: + new Meshes object containing a single mesh + """ + if isinstance(meshes, List): + meshes = join_meshes_as_batch(meshes, include_textures=False) + + if len(meshes) == 1: + return meshes + verts = meshes.verts_packed() # (sum(V_n), 3) + # Offset automatically done by faces_packed + faces = meshes.faces_packed() # (sum(F_n), 3) + + mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0)) + return mesh diff --git a/tests/data/test_joined_spheres_flat.png b/tests/data/test_joined_spheres_flat.png new file mode 100644 index 00000000..71810990 Binary files /dev/null and b/tests/data/test_joined_spheres_flat.png differ diff --git a/tests/data/test_joined_spheres_gouraud.png b/tests/data/test_joined_spheres_gouraud.png new file mode 100644 index 00000000..b34052c2 Binary files /dev/null and b/tests/data/test_joined_spheres_gouraud.png differ diff --git a/tests/data/test_joined_spheres_phong.png b/tests/data/test_joined_spheres_phong.png new file mode 100644 index 00000000..13d254a0 Binary files /dev/null and b/tests/data/test_joined_spheres_phong.png differ diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index 43e00a53..43f2a411 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -14,6 +14,7 @@ from pytorch3d.io.mtl_io import ( _bilinear_interpolation_vectorized, ) from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch +from pytorch3d.structures.meshes import join_mesh from pytorch3d.utils import torus @@ -648,6 +649,42 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0]) self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0]) + def test_join_meshes(self): + """ + Test that join_mesh joins single meshes and the corresponding values are + consistent with the single meshes. + """ + + # Load cow mesh. + DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data" + cow_obj = DATA_DIR / "cow_mesh/cow.obj" + + cow_mesh = load_objs_as_meshes([cow_obj]) + cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0) + # Join a batch of three single meshes and check that the values are consistent + # with the individual meshes. + cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh]) + + def check_item(x, y, offset): + self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y) + + check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0) + check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V) + + # Test the joining of meshes of different sizes. + teapot_obj = DATA_DIR / "teapot.obj" + teapot_mesh = load_objs_as_meshes([teapot_obj]) + teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0) + + mix_mesh = join_mesh([cow_mesh, teapot_mesh]) + mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0) + self.assertEqual(len(mix_mesh), 1) + + self.assertClose(mix_verts[: cow_mesh._V], cow_verts) + self.assertClose(mix_faces[: cow_mesh._F], cow_faces) + self.assertClose(mix_verts[cow_mesh._V :], teapot_verts) + self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V) + @staticmethod def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): return lambda: save_obj(StringIO(), verts, faces, decimal_places) diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 87c83409..0ae19471 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -26,7 +26,7 @@ from pytorch3d.renderer.mesh.shader import ( TexturedSoftPhongShader, ) from pytorch3d.renderer.mesh.texturing import Textures -from pytorch3d.structures.meshes import Meshes +from pytorch3d.structures.meshes import Meshes, join_mesh from pytorch3d.utils.ico_sphere import ico_sphere @@ -176,7 +176,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): # Init renderer rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) shaders = { - "phong": HardGouraudShader, + "phong": HardPhongShader, "gouraud": HardGouraudShader, "flat": HardFlatShader, } @@ -369,3 +369,70 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): ) self.assertClose(rgb, image_ref, atol=0.05) + + def test_joined_spheres(self): + """ + Test a list of Meshes can be joined as a single mesh and + the single mesh is rendered correctly with Phong, Gouraud + and Flat Shaders. + """ + device = torch.device("cuda:0") + + # Init mesh with vertex textures. + # Initialize a list containing two ico spheres of different sizes. + sphere_list = [ico_sphere(3, device), ico_sphere(4, device)] + # [(42 verts, 80 faces), (162 verts, 320 faces)] + # The scale the vertices need to be set at to resize the spheres + scales = [0.25, 1] + # The distance the spheres ought to be offset horizontally to prevent overlap. + offsets = [1.2, -0.3] + # Initialize a list containing the adjusted sphere meshes. + sphere_mesh_list = [] + for i in range(len(sphere_list)): + verts = sphere_list[i].verts_padded() * scales[i] + verts[0, :, 0] += offsets[i] + sphere_mesh_list.append( + Meshes(verts=verts, faces=sphere_list[i].faces_padded()) + ) + joined_sphere_mesh = join_mesh(sphere_mesh_list) + joined_sphere_mesh.textures = Textures( + verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded()) + ) + + # Init rasterizer settings + R, T = look_at_view_transform(2.7, 0.0, 0.0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + raster_settings = RasterizationSettings( + image_size=512, blur_radius=0.0, faces_per_pixel=1 + ) + + # Init shader settings + materials = Materials(device=device) + lights = PointLights(device=device) + lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None] + blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0)) + + # Init renderer + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + shaders = { + "phong": HardPhongShader, + "gouraud": HardGouraudShader, + "flat": HardFlatShader, + } + for (name, shader_init) in shaders.items(): + shader = shader_init( + lights=lights, + cameras=cameras, + materials=materials, + blend_params=blend_params, + ) + renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) + image = renderer(joined_sphere_mesh) + rgb = image[..., :3].squeeze().cpu() + if DEBUG: + file_name = "DEBUG_joined_spheres_%s.png" % name + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / file_name + ) + image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR) + self.assertClose(rgb, image_ref, atol=0.05)