Single function to load meshes from OBJs. join_meshes.

Summary:
Create the textures and the Meshes object from OBJ files in a single call.

There is functionality in OBJ files (like normals) which is ignored by this function.

Reviewed By: gkioxari

Differential Revision: D19691699

fbshipit-source-id: e26442ed80ff231b65b17d6c54c9d41e22b4e4a3
This commit is contained in:
Jeremy Reizenstein
2020-02-13 03:36:39 -08:00
committed by Facebook Github Bot
parent 23bb27956a
commit 8fe65d5f56
8 changed files with 218 additions and 44 deletions

View File

@@ -7,10 +7,13 @@ from io import StringIO
from pathlib import Path
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
from pytorch3d.structures import Meshes, Textures, join_meshes
from common_testing import TestCaseMixin
class TestMeshObjIO(unittest.TestCase):
class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
def test_load_obj_simple(self):
obj_file = "\n".join(
[
@@ -517,6 +520,88 @@ class TestMeshObjIO(unittest.TestCase):
self.assertTrue(aux.material_colors is None)
self.assertTrue(aux.texture_images is None)
def test_join_meshes(self):
"""
Test that join_meshes and load_objs_as_meshes are consistent with single
meshes.
"""
def check_triple(mesh, mesh3):
"""
Verify that mesh3 is three copies of mesh.
"""
def check_item(x, y):
self.assertEqual(x is None, y is None)
if x is not None:
self.assertClose(torch.cat([x, x, x]), y)
check_item(mesh.verts_padded(), mesh3.verts_padded())
check_item(mesh.faces_padded(), mesh3.faces_padded())
if mesh.textures is not None:
check_item(
mesh.textures.maps_padded(), mesh3.textures.maps_padded()
)
check_item(
mesh.textures.faces_uvs_padded(),
mesh3.textures.faces_uvs_padded(),
)
check_item(
mesh.textures.verts_uvs_padded(),
mesh3.textures.verts_uvs_padded(),
)
check_item(
mesh.textures.verts_rgb_padded(),
mesh3.textures.verts_rgb_padded(),
)
DATA_DIR = (
Path(__file__).resolve().parent.parent / "docs/tutorials/data"
)
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
mesh = load_objs_as_meshes([obj_filename])
mesh3 = load_objs_as_meshes([obj_filename, obj_filename, obj_filename])
check_triple(mesh, mesh3)
self.assertTupleEqual(
mesh.textures.maps_padded().shape, (1, 1024, 1024, 3)
)
mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False)
mesh3_notex = load_objs_as_meshes(
[obj_filename, obj_filename, obj_filename], load_textures=False
)
check_triple(mesh_notex, mesh3_notex)
self.assertIsNone(mesh_notex.textures)
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
mesh_rgb3 = join_meshes([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)
teapot_obj = DATA_DIR / "teapot.obj"
mesh_teapot = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
mix_mesh = load_objs_as_meshes(
[obj_filename, teapot_obj], load_textures=False
)
self.assertEqual(len(mix_mesh), 2)
self.assertClose(mix_mesh.verts_list()[0], mesh.verts_list()[0])
self.assertClose(mix_mesh.faces_list()[0], mesh.faces_list()[0])
self.assertClose(mix_mesh.verts_list()[1], teapot_verts)
self.assertClose(mix_mesh.faces_list()[1], teapot_faces)
cow3_tea = join_meshes([mesh3, mesh_teapot], include_textures=False)
self.assertEqual(len(cow3_tea), 4)
check_triple(mesh_notex, cow3_tea[:3])
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
@staticmethod
def save_obj_with_init(V: int, F: int):
verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import torch
from PIL import Image
from pytorch3d.io import load_obj
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer.cameras import (
OpenGLPerspectiveCameras,
look_at_view_transform,
@@ -274,21 +274,7 @@ class TestRenderingMeshes(unittest.TestCase):
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
# Load mesh + texture
verts, faces, aux = load_obj(obj_filename)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)
texture_uvs = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
# tex_maps is a dictionary of material names as keys and texture images
# as values. Only need the images for this example.
textures = Textures(
maps=list(tex_maps.values()),
faces_uvs=faces.textures_idx.to(torch.int64).to(device)[None, :],
verts_uvs=texture_uvs.to(torch.float32).to(device)[None, :],
)
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
mesh = load_objs_as_meshes([obj_filename], device=device)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 10, 20)
@@ -333,9 +319,11 @@ class TestRenderingMeshes(unittest.TestCase):
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
# Check grad exists
verts = verts.clone()
[verts] = mesh.verts_list()
verts.requires_grad = True
mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures)
images = renderer(mesh)
mesh2 = Meshes(
verts=[verts], faces=mesh.faces_list(), textures=mesh.textures
)
images = renderer(mesh2)
images[0, ...].sum().backward()
self.assertIsNotNone(verts.grad)