Save UV texture with obj mesh

Summary: Add functionality to to save an `.obj` file with associated UV textures: `.png` image and `.mtl` file as well as saving verts_uvs and faces_uvs to the `.obj` file.

Reviewed By: bottler

Differential Revision: D29337562

fbshipit-source-id: 86829b40dae9224088b328e7f5a16eacf8582eb5
This commit is contained in:
Nikhila Ravi 2021-06-24 15:55:09 -07:00 committed by Facebook GitHub Bot
parent 64289a491d
commit 542e2e7c07
3 changed files with 548 additions and 232 deletions

View File

@ -15,6 +15,7 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.common.types import Device from pytorch3d.common.types import Device
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
@ -649,42 +650,118 @@ def _load_obj(
def save_obj( def save_obj(
f, f: Union[str, os.PathLike],
verts, verts,
faces, faces,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None, path_manager: Optional[PathManager] = None,
): *,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
texture_map: Optional[torch.Tensor] = None,
) -> None:
""" """
Save a mesh to an .obj file. Save a mesh to an .obj file.
Args: Args:
f: File (or path) to which the mesh should be written. f: File (str or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates. verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving. decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if path_manager: Optional PathManager for interpreting f if
it is a str. it is a str.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
each vertex in the face.
texture_map: FloatTensor of shape (H, W, 3) representing the texture map
for the mesh which will be saved as an image. The values are expected
to be in the range [0, 1],
""" """
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)
if texture_map is not None and (texture_map.dim() != 3 or texture_map.size(2) != 3):
message = "'texture_map' should either be empty or of shape (H, W, 3)."
raise ValueError(message) raise ValueError(message)
if path_manager is None: if path_manager is None:
path_manager = PathManager() path_manager = PathManager()
save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]])
output_path = Path(f)
# Save the .obj file
with _open_file(f, path_manager, "w") as f: with _open_file(f, path_manager, "w") as f:
return _save(f, verts, faces, decimal_places) if save_texture:
# Add the header required for the texture info to be loaded correctly
obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem)
f.write(obj_header)
_save(
f,
verts,
faces,
decimal_places,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
save_texture=save_texture,
)
# Save the .mtl and .png files associated with the texture
if save_texture:
image_path = output_path.with_suffix(".png")
mtl_path = output_path.with_suffix(".mtl")
if isinstance(f, str):
# Back to str for iopath interpretation.
image_path = str(image_path)
mtl_path = str(mtl_path)
# Save texture map to output folder
# pyre-fixme[16] # undefined attribute cpu
texture_map = texture_map.detach().cpu() * 255.0
image = Image.fromarray(texture_map.numpy().astype(np.uint8))
with _open_file(image_path, path_manager, "wb") as im_f:
# pyre-fixme[6] # incompatible parameter type
image.save(im_f)
# Create .mtl file with the material name and texture map filename
# TODO: enable material properties to also be saved.
with _open_file(mtl_path, path_manager, "w") as f_mtl:
lines = f"newmtl mesh\n" f"map_Kd {output_path.stem}.png\n"
f_mtl.write(lines)
# TODO (nikhilar) Speed up this function. # TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None: def _save(
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) f,
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) verts,
faces,
decimal_places: Optional[int] = None,
*,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
save_texture: bool = False,
) -> None:
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if not (len(verts) or len(faces)): if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided") warnings.warn("Empty 'verts' and 'faces' arguments provided")
@ -705,15 +782,42 @@ def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
vert = [float_str % verts[i, j] for j in range(D)] vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert) lines += "v %s\n" % " ".join(vert)
if save_texture:
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)
# pyre-fixme[16] # undefined attribute cpu
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
# Save verts uvs after verts
if len(verts_uvs):
uV, uD = verts_uvs.shape
for i in range(uV):
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
lines += "vt %s\n" % " ".join(uv)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices") warnings.warn("Faces have invalid indices")
if len(faces): if len(faces):
F, P = faces.shape F, P = faces.shape
for i in range(F): for i in range(F):
if save_texture:
# Format faces as {verts_idx}/{verts_uvs_idx}
face = [
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
]
else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)] face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F: if i + 1 < F:
lines += "f %s\n" % " ".join(face) lines += "f %s\n" % " ".join(face)
elif i + 1 == F: elif i + 1 == F:
# No newline at the end of the file. # No newline at the end of the file.
lines += "f %s" % " ".join(face) lines += "f %s" % " ".join(face)

View File

@ -34,7 +34,7 @@ def get_pytorch3d_dir() -> Path:
def load_rgb_image(filename: str, data_dir: Union[str, Path]): def load_rgb_image(filename: str, data_dir: Union[str, Path]):
filepath = data_dir / filename filepath = os.path.join(data_dir, filename)
with Image.open(filepath) as raw_image: with Image.open(filepath) as raw_image:
image = torch.from_numpy(np.array(raw_image) / 255.0) image = torch.from_numpy(np.array(raw_image) / 255.0)
image = image.to(dtype=torch.float32) image = image.to(dtype=torch.float32)

View File

@ -7,12 +7,18 @@
import os import os
import unittest import unittest
import warnings import warnings
from collections import Counter
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile, TemporaryDirectory
import torch import torch
from common_testing import TestCaseMixin, get_pytorch3d_dir, get_tests_dir from common_testing import (
TestCaseMixin,
get_pytorch3d_dir,
get_tests_dir,
load_rgb_image,
)
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import ( from pytorch3d.io.mtl_io import (
@ -42,8 +48,11 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 1 2 4 3 1", # Polygons should be split into triangles "f 1 2 4 3 1", # Polygons should be split into triangles
] ]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
verts, faces, aux = load_obj(obj_file) f.write(obj_file)
f.flush()
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals normals = aux.normals
textures = aux.verts_uvs textures = aux.verts_uvs
materials = aux.material_colors materials = aux.material_colors
@ -96,8 +105,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -1 -2 1", # Negative indexing counts from the end. "f -1 -2 1", # Negative indexing counts from the end.
] ]
) )
obj_file = StringIO(obj_file)
verts, faces, aux = load_obj(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals normals = aux.normals
textures = aux.verts_uvs textures = aux.verts_uvs
materials = aux.material_colors materials = aux.material_colors
@ -139,11 +152,15 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
expected_faces_normals_idx = -( expected_faces_normals_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64) torch.ones_like(expected_faces, dtype=torch.int64)
) )
expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64) expected_faces_normals_idx[4, :] = torch.tensor(
[1, 1, 1], dtype=torch.int64
)
expected_faces_textures_idx = -( expected_faces_textures_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64) torch.ones_like(expected_faces, dtype=torch.int64)
) )
expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64) expected_faces_textures_idx[4, :] = torch.tensor(
[0, 0, 1], dtype=torch.int64
)
self.assertTrue(torch.all(verts == expected_verts)) self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces)) self.assertTrue(torch.all(faces.verts_idx == expected_faces))
@ -230,7 +247,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 2//1 3//1 4//2", "f 2//1 3//1 4//2",
] ]
) )
obj_file = StringIO(obj_file)
expected_faces_normals_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64) expected_faces_normals_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
expected_normals = torch.tensor( expected_normals = torch.tensor(
[[0.000000, 0.000000, -1.000000], [-1.000000, -0.000000, -0.000000]], [[0.000000, 0.000000, -1.000000], [-1.000000, -0.000000, -0.000000]],
@ -240,7 +257,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32, dtype=torch.float32,
) )
verts, faces, aux = load_obj(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals normals = aux.normals
textures = aux.verts_uvs textures = aux.verts_uvs
materials = aux.material_colors materials = aux.material_colors
@ -266,7 +288,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 2/1 3/1 4/2", "f 2/1 3/1 4/2",
] ]
) )
obj_file = StringIO(obj_file)
expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64) expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
expected_textures = torch.tensor( expected_textures = torch.tensor(
[[0.999110, 0.501077], [0.999455, 0.750380]], dtype=torch.float32 [[0.999110, 0.501077], [0.999455, 0.750380]], dtype=torch.float32
@ -275,7 +297,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32, dtype=torch.float32,
) )
verts, faces, aux = load_obj(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals normals = aux.normals
textures = aux.verts_uvs textures = aux.verts_uvs
materials = aux.material_colors materials = aux.material_colors
@ -293,42 +320,52 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
def test_load_obj_error_textures(self): def test_load_obj_error_textures(self):
obj_file = "\n".join(["vt 0.1"]) obj_file = "\n".join(["vt 0.1"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("does not have 2 values" in str(err.exception)) self.assertTrue("does not have 2 values" in str(err.exception))
def test_load_obj_error_normals(self): def test_load_obj_error_normals(self):
obj_file = "\n".join(["vn 0.1"]) obj_file = "\n".join(["vn 0.1"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("does not have 3 values" in str(err.exception)) self.assertTrue("does not have 3 values" in str(err.exception))
def test_load_obj_error_vertices(self): def test_load_obj_error_vertices(self):
obj_file = "\n".join(["v 1"]) obj_file = "\n".join(["v 1"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("does not have 3 values" in str(err.exception)) self.assertTrue("does not have 3 values" in str(err.exception))
def test_load_obj_error_inconsistent_triplets(self): def test_load_obj_error_inconsistent_triplets(self):
obj_file = "\n".join(["f 2//1 3/1 4/1/2"]) obj_file = "\n".join(["f 2//1 3/1 4/1/2"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("Vertex properties are inconsistent" in str(err.exception)) self.assertTrue("Vertex properties are inconsistent" in str(err.exception))
def test_load_obj_error_too_many_vertex_properties(self): def test_load_obj_error_too_many_vertex_properties(self):
obj_file = "\n".join(["f 2/1/1/3"]) obj_file = "\n".join(["f 2/1/1/3"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue( self.assertTrue(
"Face vertices can only have 3 properties" in str(err.exception) "Face vertices can only have 3 properties" in str(err.exception)
) )
@ -337,10 +374,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
obj_file = "\n".join( obj_file = "\n".join(
["v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "f -2 5 1"] ["v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "f -2 5 1"]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file) load_obj(Path(f.name))
def test_load_obj_error_invalid_normal_indices(self): def test_load_obj_error_invalid_normal_indices(self):
obj_file = "\n".join( obj_file = "\n".join(
@ -354,10 +393,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -2/2 2/4 1/1", "f -2/2 2/4 1/1",
] ]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file) load_obj(Path(f.name))
def test_load_obj_error_invalid_texture_indices(self): def test_load_obj_error_invalid_texture_indices(self):
obj_file = "\n".join( obj_file = "\n".join(
@ -371,17 +412,20 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -2//2 2//6 1//1", "f -2//2 2//6 1//1",
] ]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file) load_obj(Path(f.name))
def test_save_obj_invalid_shapes(self): def test_save_obj_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
expected_message = ( expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)." "Argument 'verts' should either be empty or of shape (num_verts, 3)."
) )
@ -391,7 +435,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
expected_message = ( expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)." "Argument 'faces' should either be empty or of shape (num_faces, 3)."
) )
@ -402,22 +447,26 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex): with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]]) faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex): with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
def _test_save_load(self, verts, faces): def _test_save_load(self, verts, faces):
f = StringIO() with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(f, verts, faces) file_path = Path(f.name)
f.seek(0) save_obj(file_path, verts, faces)
f.flush()
expected_verts, expected_faces = verts, faces expected_verts, expected_faces = verts, faces
if not len(expected_verts): # Always compare with a (V, 3) tensor if not len(expected_verts): # Always compare with a (V, 3) tensor
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32) expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
if not len(expected_faces): # Always compare with an (F, 3) tensor if not len(expected_faces): # Always compare with an (F, 3) tensor
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64) expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
actual_verts, actual_faces, _ = load_obj(f) actual_verts, actual_faces, _ = load_obj(file_path)
self.assertClose(expected_verts, actual_verts) self.assertClose(expected_verts, actual_verts)
self.assertClose(expected_faces, actual_faces.verts_idx) self.assertClose(expected_faces, actual_faces.verts_idx)
@ -467,8 +516,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
faces = torch.tensor( faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
) )
obj_file = StringIO() with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(obj_file, verts, faces, decimal_places=2) save_obj(Path(f.name), verts, faces, decimal_places=2)
expected_file = "\n".join( expected_file = "\n".join(
[ [
"v 0.01 0.20 0.30", "v 0.01 0.20 0.30",
@ -481,8 +531,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 4 2 1", "f 4 2 1",
] ]
) )
actual_file = obj_file.getvalue() actual_file = open(Path(f.name), "r")
self.assertEqual(actual_file, expected_file) self.assertEqual(actual_file.read(), expected_file)
def test_load_mtl(self): def test_load_mtl(self):
obj_filename = "cow_mesh/cow.obj" obj_filename = "cow_mesh/cow.obj"
@ -534,9 +584,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"Ns 10.0", "Ns 10.0",
] ]
) )
mtl_file = StringIO(mtl_file) with NamedTemporaryFile(mode="w", suffix=".mtl") as f:
f.write(mtl_file)
f.flush()
material_properties, texture_files = _parse_mtl( material_properties, texture_files = _parse_mtl(
mtl_file, path_manager=PathManager(), device="cpu" Path(f.name), path_manager=PathManager(), device="cpu"
) )
dtype = torch.float32 dtype = torch.float32
@ -618,9 +671,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 1 2 4", "f 1 2 4",
] ]
) )
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "No mtl file provided"): with self.assertWarnsRegex(UserWarning, "No mtl file provided"):
verts, faces, aux = load_obj(obj_file) verts, faces, aux = load_obj(Path(f.name))
expected_verts = torch.tensor( expected_verts = torch.tensor(
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
@ -825,6 +882,161 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "same type of texture"): with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas]) join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
def test_save_obj_with_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1 3/3 2/2",
"f 1/1 2/2 3/3",
"f 4/4 3/3 2/2",
"f 4/4 2/2 1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r")
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)
def test_save_obj_with_texture_errors(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255)
expected_obj_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"f 1 3 2",
"f 1 2 3",
"f 4 3 2",
"f 4 2 1",
]
)
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
# If only one of verts_uvs/faces_uvs/texture_map is provided
# then textures are not saved
for arg in [
{"verts_uvs": verts_uvs},
{"faces_uvs": faces_uvs},
{"texture_map": texture_map},
]:
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
**arg,
)
# Check there is only 1 file in the temp dir
tempfiles = ["mesh.obj"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(tempfiles, tempfiles_dir)
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
obj_file = StringIO()
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs[..., 2], # Incorrect shape
texture_map=texture_map,
)
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs[..., 0], # Incorrect shape
faces_uvs=faces_uvs,
texture_map=texture_map,
)
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map[..., 1], # Incorrect shape
)
@staticmethod @staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_obj(StringIO(), verts, faces, decimal_places) return lambda: save_obj(StringIO(), verts, faces, decimal_places)