Single function to load meshes from OBJs. join_meshes.

Summary:
Create the textures and the Meshes object from OBJ files in a single call.

There is functionality in OBJ files (like normals) which is ignored by this function.

Reviewed By: gkioxari

Differential Revision: D19691699

fbshipit-source-id: e26442ed80ff231b65b17d6c54c9d41e22b4e4a3
This commit is contained in:
Jeremy Reizenstein 2020-02-13 03:36:39 -08:00 committed by Facebook Github Bot
parent 23bb27956a
commit 8fe65d5f56
8 changed files with 218 additions and 44 deletions

View File

@ -55,6 +55,9 @@ tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image)
# Initialise the mesh with textures # Initialise the mesh with textures
meshes = Meshes(verts=[verts], faces=[faces.verts_idx], textures=tex) meshes = Meshes(verts=[verts], faces=[faces.verts_idx], textures=tex)
``` ```
The `load_objs_as_meshes` function provides this procedure.
## PLY ## PLY
Ply files are flexible in the way they store additional information, pytorch3d Ply files are flexible in the way they store additional information, pytorch3d

View File

@ -87,7 +87,7 @@
"from skimage.io import imread\n", "from skimage.io import imread\n",
"\n", "\n",
"# Util function for loading meshes\n", "# Util function for loading meshes\n",
"from pytorch3d.io import load_obj\n", "from pytorch3d.io import load_objs_as_meshes\n",
"\n", "\n",
"# Data structures and functions for rendering\n", "# Data structures and functions for rendering\n",
"from pytorch3d.structures import Meshes, Textures\n", "from pytorch3d.structures import Meshes, Textures\n",
@ -232,25 +232,8 @@
"obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n", "obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
"\n", "\n",
"# Load obj file\n", "# Load obj file\n",
"verts, faces, aux = load_obj(obj_filename)\n", "mesh = load_objs_as_meshes([obj_filename], device=device)\n",
"faces_idx = faces.verts_idx.to(device)\n", "texture_image=mesh.textures.maps_padded()"
"verts = verts.to(device)\n",
"\n",
"# Get textures from the outputs of the load_obj function\n",
"# the `aux` variable contains the texture maps and vertex uv coordinates. \n",
"# Refer to the `obj_io.load_obj` function for full API reference. \n",
"# Here we only have one texture map for the whole mesh. \n",
"verts_uvs = aux.verts_uvs[None, ...].to(device) # (N, V, 2)\n",
"faces_uvs = faces.textures_idx[None, ...].to(device) # (N, F, 3)\n",
"tex_maps = aux.texture_images\n",
"texture_image = list(tex_maps.values())[0]\n",
"texture_image = texture_image[None, ...].to(device) # (N, H, W, 3)\n",
"\n",
"# Create a textures object\n",
"tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=texture_image)\n",
"\n",
"# Create a meshes object with textures\n",
"mesh = Meshes(verts=[verts], faces=[faces_idx], textures=tex)"
] ]
}, },
{ {

View File

@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .obj_io import load_obj, save_obj from .obj_io import load_obj, load_objs_as_meshes, save_obj
from .ply_io import load_ply, save_ply from .ply_io import load_ply, save_ply
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -13,6 +13,8 @@ import torch
from fvcore.common.file_io import PathManager from fvcore.common.file_io import PathManager
from PIL import Image from PIL import Image
from pytorch3d.structures import Meshes, Textures, join_meshes
def _read_image(file_name: str, format=None): def _read_image(file_name: str, format=None):
""" """
@ -90,7 +92,7 @@ def _open_file(f):
def load_obj(f_obj, load_textures=True): def load_obj(f_obj, load_textures=True):
""" """
Load a mesh and textures from a .obj and .mtl file. Load a mesh from a .obj file and optionally textures from a .mtl file.
Currently this handles verts, faces, vertex texture uv coordinates, normals, Currently this handles verts, faces, vertex texture uv coordinates, normals,
texture images and material reflectivity values. texture images and material reflectivity values.
@ -208,6 +210,44 @@ def load_obj(f_obj, load_textures=True):
f_obj.close() f_obj.close()
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
"""
Load meshes from a list of .obj files using the load_obj function, and
return them as a Meshes object. This only works for meshes which have a
single texture image for the whole mesh. See the load_obj function for more
details. material_colors and normals are not stored.
Args:
f: A list of file-like objects (with methods read, readline, tell,
and seek), pathlib paths or strings containing file names.
device: Desired device of returned Meshes. Default:
uses the current device for the default tensor type.
load_textures: Boolean indicating whether material files are loaded
Returns:
New Meshes object.
"""
mesh_list = []
for f_obj in files:
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
verts = verts.to(device)
tex = None
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
mesh = Meshes(
verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex
)
mesh_list.append(mesh)
if len(mesh_list) == 1:
return mesh_list[0]
return join_meshes(mesh_list)
def _parse_face( def _parse_face(
line, line,
material_idx, material_idx,

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .meshes import Meshes from .meshes import Meshes, join_meshes
from .textures import Textures from .textures import Textures
from .utils import ( from .utils import (
list_to_packed, list_to_packed,

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List
import torch import torch
from pytorch3d import _C from pytorch3d import _C
@ -1365,3 +1366,77 @@ class Meshes(object):
if self.textures is not None: if self.textures is not None:
tex = self.textures.extend(N) tex = self.textures.extend(N)
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex) return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
def join_meshes(meshes: List[Meshes], include_textures: bool = True):
"""
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
must all be on the same device. If include_textures is true, they must all
be compatible, either all or none having textures, and all the Textures
objects having the same members. If include_textures is False, textures are
ignored.
Args:
meshes: list of meshes.
include_textures: (bool) whether to try to join the textures.
Returns:
new Meshes object containing all the meshes from all the inputs.
"""
if isinstance(meshes, Meshes):
# Meshes objects can be iterated and produce single Meshes. We avoid
# letting join_meshes(mesh1, mesh2) silently do the wrong thing.
raise ValueError("Wrong first argument to join_meshes.")
verts = [v for mesh in meshes for v in mesh.verts_list()]
faces = [f for mesh in meshes for f in mesh.faces_list()]
if len(meshes) == 0 or not include_textures:
return Meshes(verts=verts, faces=faces)
if meshes[0].textures is None:
if any(mesh.textures is not None for mesh in meshes):
raise ValueError("Inconsistent textures in join_meshes.")
return Meshes(verts=verts, faces=faces)
if any(mesh.textures is None for mesh in meshes):
raise ValueError("Inconsistent textures in join_meshes.")
# Now we know there are multiple meshes and they have textures to merge.
first = meshes[0].textures
kwargs = {}
if first.maps_padded() is not None:
if any(mesh.textures.maps_padded() is None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes.")
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
kwargs["maps"] = maps
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes.")
if first.verts_uvs_padded() is not None:
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
V = max(uv.shape[0] for uv in uvs)
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")
if first.faces_uvs_padded() is not None:
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
F = max(uv.shape[0] for uv in uvs)
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")
if first.verts_rgb_padded() is not None:
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
V = max(i.shape[0] for i in rgb)
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")
tex = Textures(**kwargs)
return Meshes(verts=verts, faces=faces, textures=tex)

View File

@ -7,10 +7,13 @@ from io import StringIO
from pathlib import Path from pathlib import Path
import torch import torch
from pytorch3d.io import load_obj, save_obj from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
from pytorch3d.structures import Meshes, Textures, join_meshes
from common_testing import TestCaseMixin
class TestMeshObjIO(unittest.TestCase): class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
def test_load_obj_simple(self): def test_load_obj_simple(self):
obj_file = "\n".join( obj_file = "\n".join(
[ [
@ -517,6 +520,88 @@ class TestMeshObjIO(unittest.TestCase):
self.assertTrue(aux.material_colors is None) self.assertTrue(aux.material_colors is None)
self.assertTrue(aux.texture_images is None) self.assertTrue(aux.texture_images is None)
def test_join_meshes(self):
"""
Test that join_meshes and load_objs_as_meshes are consistent with single
meshes.
"""
def check_triple(mesh, mesh3):
"""
Verify that mesh3 is three copies of mesh.
"""
def check_item(x, y):
self.assertEqual(x is None, y is None)
if x is not None:
self.assertClose(torch.cat([x, x, x]), y)
check_item(mesh.verts_padded(), mesh3.verts_padded())
check_item(mesh.faces_padded(), mesh3.faces_padded())
if mesh.textures is not None:
check_item(
mesh.textures.maps_padded(), mesh3.textures.maps_padded()
)
check_item(
mesh.textures.faces_uvs_padded(),
mesh3.textures.faces_uvs_padded(),
)
check_item(
mesh.textures.verts_uvs_padded(),
mesh3.textures.verts_uvs_padded(),
)
check_item(
mesh.textures.verts_rgb_padded(),
mesh3.textures.verts_rgb_padded(),
)
DATA_DIR = (
Path(__file__).resolve().parent.parent / "docs/tutorials/data"
)
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
mesh = load_objs_as_meshes([obj_filename])
mesh3 = load_objs_as_meshes([obj_filename, obj_filename, obj_filename])
check_triple(mesh, mesh3)
self.assertTupleEqual(
mesh.textures.maps_padded().shape, (1, 1024, 1024, 3)
)
mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False)
mesh3_notex = load_objs_as_meshes(
[obj_filename, obj_filename, obj_filename], load_textures=False
)
check_triple(mesh_notex, mesh3_notex)
self.assertIsNone(mesh_notex.textures)
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
mesh_rgb3 = join_meshes([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)
teapot_obj = DATA_DIR / "teapot.obj"
mesh_teapot = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
mix_mesh = load_objs_as_meshes(
[obj_filename, teapot_obj], load_textures=False
)
self.assertEqual(len(mix_mesh), 2)
self.assertClose(mix_mesh.verts_list()[0], mesh.verts_list()[0])
self.assertClose(mix_mesh.faces_list()[0], mesh.faces_list()[0])
self.assertClose(mix_mesh.verts_list()[1], teapot_verts)
self.assertClose(mix_mesh.faces_list()[1], teapot_faces)
cow3_tea = join_meshes([mesh3, mesh_teapot], include_textures=False)
self.assertEqual(len(cow3_tea), 4)
check_triple(mesh_notex, cow3_tea[:3])
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
@staticmethod @staticmethod
def save_obj_with_init(V: int, F: int): def save_obj_with_init(V: int, F: int):
verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)

View File

@ -11,7 +11,7 @@ from pathlib import Path
import torch import torch
from PIL import Image from PIL import Image
from pytorch3d.io import load_obj from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer.cameras import ( from pytorch3d.renderer.cameras import (
OpenGLPerspectiveCameras, OpenGLPerspectiveCameras,
look_at_view_transform, look_at_view_transform,
@ -274,21 +274,7 @@ class TestRenderingMeshes(unittest.TestCase):
obj_filename = DATA_DIR / "cow_mesh/cow.obj" obj_filename = DATA_DIR / "cow_mesh/cow.obj"
# Load mesh + texture # Load mesh + texture
verts, faces, aux = load_obj(obj_filename) mesh = load_objs_as_meshes([obj_filename], device=device)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)
texture_uvs = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
# tex_maps is a dictionary of material names as keys and texture images
# as values. Only need the images for this example.
textures = Textures(
maps=list(tex_maps.values()),
faces_uvs=faces.textures_idx.to(torch.int64).to(device)[None, :],
verts_uvs=texture_uvs.to(torch.float32).to(device)[None, :],
)
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
# Init rasterizer settings # Init rasterizer settings
R, T = look_at_view_transform(2.7, 10, 20) R, T = look_at_view_transform(2.7, 10, 20)
@ -333,9 +319,11 @@ class TestRenderingMeshes(unittest.TestCase):
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05)) self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
# Check grad exists # Check grad exists
verts = verts.clone() [verts] = mesh.verts_list()
verts.requires_grad = True verts.requires_grad = True
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures) mesh2 = Meshes(
images = renderer(mesh) verts=[verts], faces=mesh.faces_list(), textures=mesh.textures
)
images = renderer(mesh2)
images[0, ...].sum().backward() images[0, ...].sum().backward()
self.assertIsNotNone(verts.grad) self.assertIsNotNone(verts.grad)