include TexturesUV in IO.save_mesh(x.obj)

Summary:
Added export of UV textures to IO.save_mesh in Pytorch3d
MeshObjFormat now passes verts_uv, faces_uv, and texture_map as input to save_obj

TODO: check if TexturesUV.verts_uv_list or TexturesUV.verts_uv_padded() should be passed to save_obj

IO.save_mesh(obj_file, meshes, decimal_places=2) should be IO().save_mesh(obj_file, meshes, decimal_places=2)

Reviewed By: bottler

Differential Revision: D39617441

fbshipit-source-id: 4628b7f26f70e38c65f235852b990c8edb0ded23
This commit is contained in:
Michaël Ramamonjisoa 2022-09-21 06:16:48 -07:00 committed by Facebook GitHub Bot
parent 305cf32f6b
commit 6ae6ff9cf7
2 changed files with 75 additions and 0 deletions

View File

@ -334,12 +334,25 @@ class MeshObjFormat(MeshFormatInterpreter):
verts = data.verts_list()[0]
faces = data.faces_list()[0]
verts_uvs: Optional[torch.Tensor] = None
faces_uvs: Optional[torch.Tensor] = None
texture_map: Optional[torch.Tensor] = None
if isinstance(data.textures, TexturesUV):
verts_uvs = data.textures.verts_uvs_padded()[0]
faces_uvs = data.textures.faces_uvs_padded()[0]
texture_map = data.textures.maps_padded()[0]
save_obj(
f=path,
verts=verts,
faces=faces,
decimal_places=decimal_places,
path_manager=path_manager,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
return True

View File

@ -1050,6 +1050,68 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
texture_map=texture_map[..., 1], # Incorrect shape
)
def test_save_obj_with_texture_IO(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
textures_uv = TexturesUV([texture_map], [faces_uvs], [verts_uvs])
test_mesh = Meshes(verts=[verts], faces=[faces], textures=textures_uv)
IO().save_mesh(data=test_mesh, path=obj_file, decimal_places=2)
expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1 3/3 2/2",
"f 1/1 2/2 3/3",
"f 4/4 3/3 2/2",
"f 4/4 2/2 1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r")
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)
@staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_obj(StringIO(), verts, faces, decimal_places)