diff --git a/docs/notes/meshes_io.md b/docs/notes/meshes_io.md index b4e00235..e6d02209 100644 --- a/docs/notes/meshes_io.md +++ b/docs/notes/meshes_io.md @@ -55,6 +55,9 @@ tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image) # Initialise the mesh with textures meshes = Meshes(verts=[verts], faces=[faces.verts_idx], textures=tex) ``` + +The `load_objs_as_meshes` function provides this procedure. + ## PLY Ply files are flexible in the way they store additional information, pytorch3d diff --git a/docs/tutorials/render_textured_meshes.ipynb b/docs/tutorials/render_textured_meshes.ipynb index 613ac98a..ef44a581 100644 --- a/docs/tutorials/render_textured_meshes.ipynb +++ b/docs/tutorials/render_textured_meshes.ipynb @@ -87,7 +87,7 @@ "from skimage.io import imread\n", "\n", "# Util function for loading meshes\n", - "from pytorch3d.io import load_obj\n", + "from pytorch3d.io import load_objs_as_meshes\n", "\n", "# Data structures and functions for rendering\n", "from pytorch3d.structures import Meshes, Textures\n", @@ -232,25 +232,8 @@ "obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n", "\n", "# Load obj file\n", - "verts, faces, aux = load_obj(obj_filename)\n", - "faces_idx = faces.verts_idx.to(device)\n", - "verts = verts.to(device)\n", - "\n", - "# Get textures from the outputs of the load_obj function\n", - "# the `aux` variable contains the texture maps and vertex uv coordinates. \n", - "# Refer to the `obj_io.load_obj` function for full API reference. \n", - "# Here we only have one texture map for the whole mesh. \n", - "verts_uvs = aux.verts_uvs[None, ...].to(device) # (N, V, 2)\n", - "faces_uvs = faces.textures_idx[None, ...].to(device) # (N, F, 3)\n", - "tex_maps = aux.texture_images\n", - "texture_image = list(tex_maps.values())[0]\n", - "texture_image = texture_image[None, ...].to(device) # (N, H, W, 3)\n", - "\n", - "# Create a textures object\n", - "tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image)\n", - "\n", - "# Create a meshes object with textures\n", - "mesh = Meshes(verts=[verts], faces=[faces_idx], textures=tex)" + "mesh = load_objs_as_meshes([obj_filename], device=device)\n", + "texture_image=mesh.textures.maps_padded()" ] }, { diff --git a/pytorch3d/io/__init__.py b/pytorch3d/io/__init__.py index 1ef515ac..0162ae89 100644 --- a/pytorch3d/io/__init__.py +++ b/pytorch3d/io/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .obj_io import load_obj, save_obj +from .obj_io import load_obj, load_objs_as_meshes, save_obj from .ply_io import load_ply, save_ply __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index d85fa14e..31ee9d38 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -13,6 +13,8 @@ import torch from fvcore.common.file_io import PathManager from PIL import Image +from pytorch3d.structures import Meshes, Textures, join_meshes + def _read_image(file_name: str, format=None): """ @@ -90,7 +92,7 @@ def _open_file(f): def load_obj(f_obj, load_textures=True): """ - Load a mesh and textures from a .obj and .mtl file. + Load a mesh from a .obj file and optionally textures from a .mtl file. Currently this handles verts, faces, vertex texture uv coordinates, normals, texture images and material reflectivity values. @@ -208,6 +210,44 @@ def load_obj(f_obj, load_textures=True): f_obj.close() +def load_objs_as_meshes(files: list, device=None, load_textures: bool = True): + """ + Load meshes from a list of .obj files using the load_obj function, and + return them as a Meshes object. This only works for meshes which have a + single texture image for the whole mesh. See the load_obj function for more + details. material_colors and normals are not stored. + + Args: + f: A list of file-like objects (with methods read, readline, tell, + and seek), pathlib paths or strings containing file names. + device: Desired device of returned Meshes. Default: + uses the current device for the default tensor type. + load_textures: Boolean indicating whether material files are loaded + + Returns: + New Meshes object. + """ + mesh_list = [] + for f_obj in files: + verts, faces, aux = load_obj(f_obj, load_textures=load_textures) + verts = verts.to(device) + tex = None + tex_maps = aux.texture_images + if tex_maps is not None and len(tex_maps) > 0: + verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2) + faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3) + image = list(tex_maps.values())[0].to(device)[None] + tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image) + + mesh = Meshes( + verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex + ) + mesh_list.append(mesh) + if len(mesh_list) == 1: + return mesh_list[0] + return join_meshes(mesh_list) + + def _parse_face( line, material_idx, diff --git a/pytorch3d/structures/__init__.py b/pytorch3d/structures/__init__.py index 78cb585a..e0911804 100644 --- a/pytorch3d/structures/__init__.py +++ b/pytorch3d/structures/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .meshes import Meshes +from .meshes import Meshes, join_meshes from .textures import Textures from .utils import ( list_to_packed, diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 64cf79f1..fdb6bcb3 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import List import torch from pytorch3d import _C @@ -1365,3 +1366,77 @@ class Meshes(object): if self.textures is not None: tex = self.textures.extend(N) return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex) + + +def join_meshes(meshes: List[Meshes], include_textures: bool = True): + """ + Merge multiple Meshes objects, i.e. concatenate the meshes objects. They + must all be on the same device. If include_textures is true, they must all + be compatible, either all or none having textures, and all the Textures + objects having the same members. If include_textures is False, textures are + ignored. + + Args: + meshes: list of meshes. + include_textures: (bool) whether to try to join the textures. + + Returns: + new Meshes object containing all the meshes from all the inputs. + """ + if isinstance(meshes, Meshes): + # Meshes objects can be iterated and produce single Meshes. We avoid + # letting join_meshes(mesh1, mesh2) silently do the wrong thing. + raise ValueError("Wrong first argument to join_meshes.") + verts = [v for mesh in meshes for v in mesh.verts_list()] + faces = [f for mesh in meshes for f in mesh.faces_list()] + if len(meshes) == 0 or not include_textures: + return Meshes(verts=verts, faces=faces) + + if meshes[0].textures is None: + if any(mesh.textures is not None for mesh in meshes): + raise ValueError("Inconsistent textures in join_meshes.") + return Meshes(verts=verts, faces=faces) + + if any(mesh.textures is None for mesh in meshes): + raise ValueError("Inconsistent textures in join_meshes.") + + # Now we know there are multiple meshes and they have textures to merge. + first = meshes[0].textures + kwargs = {} + if first.maps_padded() is not None: + if any(mesh.textures.maps_padded() is None for mesh in meshes): + raise ValueError("Inconsistent maps_padded in join_meshes.") + maps = [m for mesh in meshes for m in mesh.textures.maps_padded()] + kwargs["maps"] = maps + elif any(mesh.textures.maps_padded() is not None for mesh in meshes): + raise ValueError("Inconsistent maps_padded in join_meshes.") + + if first.verts_uvs_padded() is not None: + if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes): + raise ValueError("Inconsistent verts_uvs_padded in join_meshes.") + uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()] + V = max(uv.shape[0] for uv in uvs) + kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1) + elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes): + raise ValueError("Inconsistent verts_uvs_padded in join_meshes.") + + if first.faces_uvs_padded() is not None: + if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes): + raise ValueError("Inconsistent faces_uvs_padded in join_meshes.") + uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()] + F = max(uv.shape[0] for uv in uvs) + kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1) + elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes): + raise ValueError("Inconsistent faces_uvs_padded in join_meshes.") + + if first.verts_rgb_padded() is not None: + if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes): + raise ValueError("Inconsistent verts_rgb_padded in join_meshes.") + rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()] + V = max(i.shape[0] for i in rgb) + kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3)) + elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes): + raise ValueError("Inconsistent verts_rgb_padded in join_meshes.") + + tex = Textures(**kwargs) + return Meshes(verts=verts, faces=faces, textures=tex) diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index 336dbfdb..f12bb8bb 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -7,10 +7,13 @@ from io import StringIO from pathlib import Path import torch -from pytorch3d.io import load_obj, save_obj +from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj +from pytorch3d.structures import Meshes, Textures, join_meshes + +from common_testing import TestCaseMixin -class TestMeshObjIO(unittest.TestCase): +class TestMeshObjIO(TestCaseMixin, unittest.TestCase): def test_load_obj_simple(self): obj_file = "\n".join( [ @@ -517,6 +520,88 @@ class TestMeshObjIO(unittest.TestCase): self.assertTrue(aux.material_colors is None) self.assertTrue(aux.texture_images is None) + def test_join_meshes(self): + """ + Test that join_meshes and load_objs_as_meshes are consistent with single + meshes. + """ + + def check_triple(mesh, mesh3): + """ + Verify that mesh3 is three copies of mesh. + """ + + def check_item(x, y): + self.assertEqual(x is None, y is None) + if x is not None: + self.assertClose(torch.cat([x, x, x]), y) + + check_item(mesh.verts_padded(), mesh3.verts_padded()) + check_item(mesh.faces_padded(), mesh3.faces_padded()) + if mesh.textures is not None: + check_item( + mesh.textures.maps_padded(), mesh3.textures.maps_padded() + ) + check_item( + mesh.textures.faces_uvs_padded(), + mesh3.textures.faces_uvs_padded(), + ) + check_item( + mesh.textures.verts_uvs_padded(), + mesh3.textures.verts_uvs_padded(), + ) + check_item( + mesh.textures.verts_rgb_padded(), + mesh3.textures.verts_rgb_padded(), + ) + + DATA_DIR = ( + Path(__file__).resolve().parent.parent / "docs/tutorials/data" + ) + obj_filename = DATA_DIR / "cow_mesh/cow.obj" + + mesh = load_objs_as_meshes([obj_filename]) + mesh3 = load_objs_as_meshes([obj_filename, obj_filename, obj_filename]) + check_triple(mesh, mesh3) + self.assertTupleEqual( + mesh.textures.maps_padded().shape, (1, 1024, 1024, 3) + ) + + mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False) + mesh3_notex = load_objs_as_meshes( + [obj_filename, obj_filename, obj_filename], load_textures=False + ) + check_triple(mesh_notex, mesh3_notex) + self.assertIsNone(mesh_notex.textures) + + verts = torch.randn((4, 3), dtype=torch.float32) + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) + vert_tex = torch.tensor( + [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32 + ) + tex = Textures(verts_rgb=vert_tex[None, :]) + mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex) + mesh_rgb3 = join_meshes([mesh_rgb, mesh_rgb, mesh_rgb]) + check_triple(mesh_rgb, mesh_rgb3) + + teapot_obj = DATA_DIR / "teapot.obj" + mesh_teapot = load_objs_as_meshes([teapot_obj]) + teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0) + mix_mesh = load_objs_as_meshes( + [obj_filename, teapot_obj], load_textures=False + ) + self.assertEqual(len(mix_mesh), 2) + self.assertClose(mix_mesh.verts_list()[0], mesh.verts_list()[0]) + self.assertClose(mix_mesh.faces_list()[0], mesh.faces_list()[0]) + self.assertClose(mix_mesh.verts_list()[1], teapot_verts) + self.assertClose(mix_mesh.faces_list()[1], teapot_faces) + + cow3_tea = join_meshes([mesh3, mesh_teapot], include_textures=False) + self.assertEqual(len(cow3_tea), 4) + check_triple(mesh_notex, cow3_tea[:3]) + self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0]) + self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0]) + @staticmethod def save_obj_with_init(V: int, F: int): verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) diff --git a/tests/test_rendering_meshes.py b/tests/test_rendering_meshes.py index b3cb1966..8422f395 100644 --- a/tests/test_rendering_meshes.py +++ b/tests/test_rendering_meshes.py @@ -11,7 +11,7 @@ from pathlib import Path import torch from PIL import Image -from pytorch3d.io import load_obj +from pytorch3d.io import load_objs_as_meshes from pytorch3d.renderer.cameras import ( OpenGLPerspectiveCameras, look_at_view_transform, @@ -274,21 +274,7 @@ class TestRenderingMeshes(unittest.TestCase): obj_filename = DATA_DIR / "cow_mesh/cow.obj" # Load mesh + texture - verts, faces, aux = load_obj(obj_filename) - faces_idx = faces.verts_idx.to(device) - verts = verts.to(device) - texture_uvs = aux.verts_uvs - materials = aux.material_colors - tex_maps = aux.texture_images - - # tex_maps is a dictionary of material names as keys and texture images - # as values. Only need the images for this example. - textures = Textures( - maps=list(tex_maps.values()), - faces_uvs=faces.textures_idx.to(torch.int64).to(device)[None, :], - verts_uvs=texture_uvs.to(torch.float32).to(device)[None, :], - ) - mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures) + mesh = load_objs_as_meshes([obj_filename], device=device) # Init rasterizer settings R, T = look_at_view_transform(2.7, 10, 20) @@ -333,9 +319,11 @@ class TestRenderingMeshes(unittest.TestCase): self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05)) # Check grad exists - verts = verts.clone() + [verts] = mesh.verts_list() verts.requires_grad = True - mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures) - images = renderer(mesh) + mesh2 = Meshes( + verts=[verts], faces=mesh.faces_list(), textures=mesh.textures + ) + images = renderer(mesh2) images[0, ...].sum().backward() self.assertIsNotNone(verts.grad)