Save UV texture with obj mesh

Summary: Add functionality to to save an `.obj` file with associated UV textures: `.png` image and `.mtl` file as well as saving verts_uvs and faces_uvs to the `.obj` file.

Reviewed By: bottler

Differential Revision: D29337562

fbshipit-source-id: 86829b40dae9224088b328e7f5a16eacf8582eb5
This commit is contained in:
Nikhila Ravi 2021-06-24 15:55:09 -07:00 committed by Facebook GitHub Bot
parent 64289a491d
commit 542e2e7c07
3 changed files with 548 additions and 232 deletions

View File

@ -15,6 +15,7 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.common.types import Device from pytorch3d.common.types import Device
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
@ -649,42 +650,118 @@ def _load_obj(
def save_obj( def save_obj(
f, f: Union[str, os.PathLike],
verts, verts,
faces, faces,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None, path_manager: Optional[PathManager] = None,
): *,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
texture_map: Optional[torch.Tensor] = None,
) -> None:
""" """
Save a mesh to an .obj file. Save a mesh to an .obj file.
Args: Args:
f: File (or path) to which the mesh should be written. f: File (str or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates. verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving. decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if path_manager: Optional PathManager for interpreting f if
it is a str. it is a str.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
each vertex in the face.
texture_map: FloatTensor of shape (H, W, 3) representing the texture map
for the mesh which will be saved as an image. The values are expected
to be in the range [0, 1],
""" """
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)
if texture_map is not None and (texture_map.dim() != 3 or texture_map.size(2) != 3):
message = "'texture_map' should either be empty or of shape (H, W, 3)."
raise ValueError(message) raise ValueError(message)
if path_manager is None: if path_manager is None:
path_manager = PathManager() path_manager = PathManager()
save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]])
output_path = Path(f)
# Save the .obj file
with _open_file(f, path_manager, "w") as f: with _open_file(f, path_manager, "w") as f:
return _save(f, verts, faces, decimal_places) if save_texture:
# Add the header required for the texture info to be loaded correctly
obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem)
f.write(obj_header)
_save(
f,
verts,
faces,
decimal_places,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
save_texture=save_texture,
)
# Save the .mtl and .png files associated with the texture
if save_texture:
image_path = output_path.with_suffix(".png")
mtl_path = output_path.with_suffix(".mtl")
if isinstance(f, str):
# Back to str for iopath interpretation.
image_path = str(image_path)
mtl_path = str(mtl_path)
# Save texture map to output folder
# pyre-fixme[16] # undefined attribute cpu
texture_map = texture_map.detach().cpu() * 255.0
image = Image.fromarray(texture_map.numpy().astype(np.uint8))
with _open_file(image_path, path_manager, "wb") as im_f:
# pyre-fixme[6] # incompatible parameter type
image.save(im_f)
# Create .mtl file with the material name and texture map filename
# TODO: enable material properties to also be saved.
with _open_file(mtl_path, path_manager, "w") as f_mtl:
lines = f"newmtl mesh\n" f"map_Kd {output_path.stem}.png\n"
f_mtl.write(lines)
# TODO (nikhilar) Speed up this function. # TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None: def _save(
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) f,
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) verts,
faces,
decimal_places: Optional[int] = None,
*,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
save_texture: bool = False,
) -> None:
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if not (len(verts) or len(faces)): if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided") warnings.warn("Empty 'verts' and 'faces' arguments provided")
@ -705,15 +782,42 @@ def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
vert = [float_str % verts[i, j] for j in range(D)] vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert) lines += "v %s\n" % " ".join(vert)
if save_texture:
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)
# pyre-fixme[16] # undefined attribute cpu
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
# Save verts uvs after verts
if len(verts_uvs):
uV, uD = verts_uvs.shape
for i in range(uV):
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
lines += "vt %s\n" % " ".join(uv)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices") warnings.warn("Faces have invalid indices")
if len(faces): if len(faces):
F, P = faces.shape F, P = faces.shape
for i in range(F): for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)] if save_texture:
# Format faces as {verts_idx}/{verts_uvs_idx}
face = [
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
]
else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F: if i + 1 < F:
lines += "f %s\n" % " ".join(face) lines += "f %s\n" % " ".join(face)
elif i + 1 == F: elif i + 1 == F:
# No newline at the end of the file. # No newline at the end of the file.
lines += "f %s" % " ".join(face) lines += "f %s" % " ".join(face)

View File

@ -34,7 +34,7 @@ def get_pytorch3d_dir() -> Path:
def load_rgb_image(filename: str, data_dir: Union[str, Path]): def load_rgb_image(filename: str, data_dir: Union[str, Path]):
filepath = data_dir / filename filepath = os.path.join(data_dir, filename)
with Image.open(filepath) as raw_image: with Image.open(filepath) as raw_image:
image = torch.from_numpy(np.array(raw_image) / 255.0) image = torch.from_numpy(np.array(raw_image) / 255.0)
image = image.to(dtype=torch.float32) image = image.to(dtype=torch.float32)

View File

@ -7,12 +7,18 @@
import os import os
import unittest import unittest
import warnings import warnings
from collections import Counter
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile, TemporaryDirectory
import torch import torch
from common_testing import TestCaseMixin, get_pytorch3d_dir, get_tests_dir from common_testing import (
TestCaseMixin,
get_pytorch3d_dir,
get_tests_dir,
load_rgb_image,
)
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import ( from pytorch3d.io.mtl_io import (
@ -42,38 +48,41 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 1 2 4 3 1", # Polygons should be split into triangles "f 1 2 4 3 1", # Polygons should be split into triangles
] ]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
verts, faces, aux = load_obj(obj_file) f.write(obj_file)
normals = aux.normals f.flush()
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
expected_verts = torch.tensor( verts, faces, aux = load_obj(Path(f.name))
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], normals = aux.normals
dtype=torch.float32, textures = aux.verts_uvs
) materials = aux.material_colors
expected_faces = torch.tensor( tex_maps = aux.texture_images
[
[0, 1, 2], # First face expected_verts = torch.tensor(
[0, 1, 3], # Second face (polygon) [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
[0, 3, 2], # Second face (polygon) dtype=torch.float32,
[0, 2, 0], # Second face (polygon) )
], expected_faces = torch.tensor(
dtype=torch.int64, [
) [0, 1, 2], # First face
self.assertTrue(torch.all(verts == expected_verts)) [0, 1, 3], # Second face (polygon)
self.assertTrue(torch.all(faces.verts_idx == expected_faces)) [0, 3, 2], # Second face (polygon)
padded_vals = -(torch.ones_like(faces.verts_idx)) [0, 2, 0], # Second face (polygon)
self.assertTrue(torch.all(faces.normals_idx == padded_vals)) ],
self.assertTrue(torch.all(faces.textures_idx == padded_vals)) dtype=torch.int64,
self.assertTrue( )
torch.all(faces.materials_idx == -(torch.ones(len(expected_faces)))) self.assertTrue(torch.all(verts == expected_verts))
) self.assertTrue(torch.all(faces.verts_idx == expected_faces))
self.assertTrue(normals is None) padded_vals = -(torch.ones_like(faces.verts_idx))
self.assertTrue(textures is None) self.assertTrue(torch.all(faces.normals_idx == padded_vals))
self.assertTrue(materials is None) self.assertTrue(torch.all(faces.textures_idx == padded_vals))
self.assertTrue(tex_maps is None) self.assertTrue(
torch.all(faces.materials_idx == -(torch.ones(len(expected_faces))))
)
self.assertTrue(normals is None)
self.assertTrue(textures is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_complex(self): def test_load_obj_complex(self):
obj_file = "\n".join( obj_file = "\n".join(
@ -96,63 +105,71 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -1 -2 1", # Negative indexing counts from the end. "f -1 -2 1", # Negative indexing counts from the end.
] ]
) )
obj_file = StringIO(obj_file)
verts, faces, aux = load_obj(obj_file)
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
expected_verts = torch.tensor( with NamedTemporaryFile(mode="w", suffix=".obj") as f:
[ f.write(obj_file)
[0.1, 0.2, 0.3], f.flush()
[0.2, 0.3, 0.4],
[0.3, 0.4, 0.5],
[0.4, 0.5, 0.6],
[0.5, 0.6, 0.7],
],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 4], # Second face (polygon)
[1, 2, 3], # Third face (normals / texture)
[4, 3, 0], # Fourth face (negative indices)
],
dtype=torch.int64,
)
expected_normals = torch.tensor(
[
[0.000000, 0.000000, -1.000000],
[-1.000000, -0.000000, -0.000000],
[-0.000000, -0.000000, 1.000000],
],
dtype=torch.float32,
)
expected_textures = torch.tensor(
[[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
dtype=torch.float32,
)
expected_faces_normals_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64)
expected_faces_textures_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64)
self.assertTrue(torch.all(verts == expected_verts)) verts, faces, aux = load_obj(Path(f.name))
self.assertTrue(torch.all(faces.verts_idx == expected_faces)) normals = aux.normals
self.assertClose(normals, expected_normals) textures = aux.verts_uvs
self.assertClose(textures, expected_textures) materials = aux.material_colors
self.assertClose(faces.normals_idx, expected_faces_normals_idx) tex_maps = aux.texture_images
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertTrue(materials is None) expected_verts = torch.tensor(
self.assertTrue(tex_maps is None) [
[0.1, 0.2, 0.3],
[0.2, 0.3, 0.4],
[0.3, 0.4, 0.5],
[0.4, 0.5, 0.6],
[0.5, 0.6, 0.7],
],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 4], # Second face (polygon)
[1, 2, 3], # Third face (normals / texture)
[4, 3, 0], # Fourth face (negative indices)
],
dtype=torch.int64,
)
expected_normals = torch.tensor(
[
[0.000000, 0.000000, -1.000000],
[-1.000000, -0.000000, -0.000000],
[-0.000000, -0.000000, 1.000000],
],
dtype=torch.float32,
)
expected_textures = torch.tensor(
[[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
dtype=torch.float32,
)
expected_faces_normals_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_normals_idx[4, :] = torch.tensor(
[1, 1, 1], dtype=torch.int64
)
expected_faces_textures_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_textures_idx[4, :] = torch.tensor(
[0, 0, 1], dtype=torch.int64
)
self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
self.assertClose(normals, expected_normals)
self.assertClose(textures, expected_textures)
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_complex_pluggable(self): def test_load_obj_complex_pluggable(self):
""" """
@ -230,7 +247,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 2//1 3//1 4//2", "f 2//1 3//1 4//2",
] ]
) )
obj_file = StringIO(obj_file)
expected_faces_normals_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64) expected_faces_normals_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
expected_normals = torch.tensor( expected_normals = torch.tensor(
[[0.000000, 0.000000, -1.000000], [-1.000000, -0.000000, -0.000000]], [[0.000000, 0.000000, -1.000000], [-1.000000, -0.000000, -0.000000]],
@ -240,19 +257,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32, dtype=torch.float32,
) )
verts, faces, aux = load_obj(obj_file)
normals = aux.normals with NamedTemporaryFile(mode="w", suffix=".obj") as f:
textures = aux.verts_uvs f.write(obj_file)
materials = aux.material_colors f.flush()
tex_maps = aux.texture_images
self.assertClose(faces.normals_idx, expected_faces_normals_idx) verts, faces, aux = load_obj(Path(f.name))
self.assertClose(normals, expected_normals) normals = aux.normals
self.assertClose(verts, expected_verts) textures = aux.verts_uvs
# Textures idx padded with -1. materials = aux.material_colors
self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1) tex_maps = aux.texture_images
self.assertTrue(textures is None) self.assertClose(faces.normals_idx, expected_faces_normals_idx)
self.assertTrue(materials is None) self.assertClose(normals, expected_normals)
self.assertTrue(tex_maps is None) self.assertClose(verts, expected_verts)
# Textures idx padded with -1.
self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1)
self.assertTrue(textures is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_textures_only(self): def test_load_obj_textures_only(self):
obj_file = "\n".join( obj_file = "\n".join(
@ -266,7 +288,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 2/1 3/1 4/2", "f 2/1 3/1 4/2",
] ]
) )
obj_file = StringIO(obj_file)
expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64) expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
expected_textures = torch.tensor( expected_textures = torch.tensor(
[[0.999110, 0.501077], [0.999455, 0.750380]], dtype=torch.float32 [[0.999110, 0.501077], [0.999455, 0.750380]], dtype=torch.float32
@ -275,72 +297,89 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32, dtype=torch.float32,
) )
verts, faces, aux = load_obj(obj_file)
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
self.assertClose(faces.textures_idx, expected_faces_textures_idx) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
self.assertClose(expected_textures, textures) f.write(obj_file)
self.assertClose(expected_verts, verts) f.flush()
self.assertTrue(
torch.all(faces.normals_idx == -(torch.ones_like(faces.textures_idx))) verts, faces, aux = load_obj(Path(f.name))
) normals = aux.normals
self.assertTrue(normals is None) textures = aux.verts_uvs
self.assertTrue(materials is None) materials = aux.material_colors
self.assertTrue(tex_maps is None) tex_maps = aux.texture_images
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertClose(expected_textures, textures)
self.assertClose(expected_verts, verts)
self.assertTrue(
torch.all(faces.normals_idx == -(torch.ones_like(faces.textures_idx)))
)
self.assertTrue(normals is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_error_textures(self): def test_load_obj_error_textures(self):
obj_file = "\n".join(["vt 0.1"]) obj_file = "\n".join(["vt 0.1"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("does not have 2 values" in str(err.exception)) self.assertTrue("does not have 2 values" in str(err.exception))
def test_load_obj_error_normals(self): def test_load_obj_error_normals(self):
obj_file = "\n".join(["vn 0.1"]) obj_file = "\n".join(["vn 0.1"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("does not have 3 values" in str(err.exception)) self.assertTrue("does not have 3 values" in str(err.exception))
def test_load_obj_error_vertices(self): def test_load_obj_error_vertices(self):
obj_file = "\n".join(["v 1"]) obj_file = "\n".join(["v 1"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("does not have 3 values" in str(err.exception)) self.assertTrue("does not have 3 values" in str(err.exception))
def test_load_obj_error_inconsistent_triplets(self): def test_load_obj_error_inconsistent_triplets(self):
obj_file = "\n".join(["f 2//1 3/1 4/1/2"]) obj_file = "\n".join(["f 2//1 3/1 4/1/2"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue("Vertex properties are inconsistent" in str(err.exception)) self.assertTrue("Vertex properties are inconsistent" in str(err.exception))
def test_load_obj_error_too_many_vertex_properties(self): def test_load_obj_error_too_many_vertex_properties(self):
obj_file = "\n".join(["f 2/1/1/3"]) obj_file = "\n".join(["f 2/1/1/3"])
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err: with self.assertRaises(ValueError) as err:
load_obj(obj_file) load_obj(Path(f.name))
self.assertTrue( self.assertTrue(
"Face vertices can only have 3 properties" in str(err.exception) "Face vertices can only have 3 properties" in str(err.exception)
) )
def test_load_obj_error_invalid_vertex_indices(self): def test_load_obj_error_invalid_vertex_indices(self):
obj_file = "\n".join( obj_file = "\n".join(
["v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "f -2 5 1"] ["v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "f -2 5 1"]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file) load_obj(Path(f.name))
def test_load_obj_error_invalid_normal_indices(self): def test_load_obj_error_invalid_normal_indices(self):
obj_file = "\n".join( obj_file = "\n".join(
@ -354,10 +393,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -2/2 2/4 1/1", "f -2/2 2/4 1/1",
] ]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file) load_obj(Path(f.name))
def test_load_obj_error_invalid_texture_indices(self): def test_load_obj_error_invalid_texture_indices(self):
obj_file = "\n".join( obj_file = "\n".join(
@ -371,17 +412,20 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -2//2 2//6 1//1", "f -2//2 2//6 1//1",
] ]
) )
obj_file = StringIO(obj_file) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file) load_obj(Path(f.name))
def test_save_obj_invalid_shapes(self): def test_save_obj_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
expected_message = ( expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)." "Argument 'verts' should either be empty or of shape (num_verts, 3)."
) )
@ -391,7 +435,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
expected_message = ( expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)." "Argument 'faces' should either be empty or of shape (num_faces, 3)."
) )
@ -402,24 +447,28 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex): with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]]) faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex): with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces) with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
def _test_save_load(self, verts, faces): def _test_save_load(self, verts, faces):
f = StringIO() with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(f, verts, faces) file_path = Path(f.name)
f.seek(0) save_obj(file_path, verts, faces)
expected_verts, expected_faces = verts, faces f.flush()
if not len(expected_verts): # Always compare with a (V, 3) tensor
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32) expected_verts, expected_faces = verts, faces
if not len(expected_faces): # Always compare with an (F, 3) tensor if not len(expected_verts): # Always compare with a (V, 3) tensor
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64) expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
actual_verts, actual_faces, _ = load_obj(f) if not len(expected_faces): # Always compare with an (F, 3) tensor
self.assertClose(expected_verts, actual_verts) expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
self.assertClose(expected_faces, actual_faces.verts_idx) actual_verts, actual_faces, _ = load_obj(file_path)
self.assertClose(expected_verts, actual_verts)
self.assertClose(expected_faces, actual_faces.verts_idx)
def test_empty_save_load_obj(self): def test_empty_save_load_obj(self):
# Vertices + empty faces # Vertices + empty faces
@ -467,22 +516,23 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
faces = torch.tensor( faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
) )
obj_file = StringIO() with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(obj_file, verts, faces, decimal_places=2) save_obj(Path(f.name), verts, faces, decimal_places=2)
expected_file = "\n".join(
[ expected_file = "\n".join(
"v 0.01 0.20 0.30", [
"v 0.20 0.03 0.41", "v 0.01 0.20 0.30",
"v 0.30 0.40 0.05", "v 0.20 0.03 0.41",
"v 0.60 0.70 0.80", "v 0.30 0.40 0.05",
"f 1 3 2", "v 0.60 0.70 0.80",
"f 1 2 3", "f 1 3 2",
"f 4 3 2", "f 1 2 3",
"f 4 2 1", "f 4 3 2",
] "f 4 2 1",
) ]
actual_file = obj_file.getvalue() )
self.assertEqual(actual_file, expected_file) actual_file = open(Path(f.name), "r")
self.assertEqual(actual_file.read(), expected_file)
def test_load_mtl(self): def test_load_mtl(self):
obj_filename = "cow_mesh/cow.obj" obj_filename = "cow_mesh/cow.obj"
@ -534,36 +584,39 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"Ns 10.0", "Ns 10.0",
] ]
) )
mtl_file = StringIO(mtl_file) with NamedTemporaryFile(mode="w", suffix=".mtl") as f:
material_properties, texture_files = _parse_mtl( f.write(mtl_file)
mtl_file, path_manager=PathManager(), device="cpu" f.flush()
)
dtype = torch.float32 material_properties, texture_files = _parse_mtl(
expected_materials = { Path(f.name), path_manager=PathManager(), device="cpu"
"material_1": { )
"ambient_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype),
"diffuse_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype), dtype = torch.float32
"specular_color": torch.tensor([0.0, 0.0, 0.0], dtype=dtype), expected_materials = {
"shininess": torch.tensor([10.0], dtype=dtype), "material_1": {
"ambient_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype),
"diffuse_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype),
"specular_color": torch.tensor([0.0, 0.0, 0.0], dtype=dtype),
"shininess": torch.tensor([10.0], dtype=dtype),
}
} }
} # Check that there is a material with name material_1
# Check that there is a material with name material_1 self.assertTrue(tuple(texture_files.keys()) == ("material_1",))
self.assertTrue(tuple(texture_files.keys()) == ("material_1",)) # Check that there is an image with name material 1.png
# Check that there is an image with name material 1.png self.assertTrue(texture_files["material_1"] == "material 1.png")
self.assertTrue(texture_files["material_1"] == "material 1.png")
# Check all keys and values in dictionary are the same. # Check all keys and values in dictionary are the same.
for n1, n2 in zip(material_properties.keys(), expected_materials.keys()): for n1, n2 in zip(material_properties.keys(), expected_materials.keys()):
self.assertTrue(n1 == n2) self.assertTrue(n1 == n2)
for k1, k2 in zip( for k1, k2 in zip(
material_properties[n1].keys(), expected_materials[n2].keys() material_properties[n1].keys(), expected_materials[n2].keys()
): ):
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(
material_properties[n1][k1], expected_materials[n2][k2] material_properties[n1][k1], expected_materials[n2][k2]
)
) )
)
def test_load_mtl_texture_atlas_compare_softras(self): def test_load_mtl_texture_atlas_compare_softras(self):
# Load saved texture atlas created with SoftRas. # Load saved texture atlas created with SoftRas.
@ -618,21 +671,25 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 1 2 4", "f 1 2 4",
] ]
) )
obj_file = StringIO(obj_file)
with self.assertWarnsRegex(UserWarning, "No mtl file provided"):
verts, faces, aux = load_obj(obj_file)
expected_verts = torch.tensor( with NamedTemporaryFile(mode="w", suffix=".obj") as f:
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], f.write(obj_file)
dtype=torch.float32, f.flush()
)
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64) with self.assertWarnsRegex(UserWarning, "No mtl file provided"):
self.assertTrue(torch.allclose(verts, expected_verts)) verts, faces, aux = load_obj(Path(f.name))
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
self.assertTrue(aux.material_colors is None) expected_verts = torch.tensor(
self.assertTrue(aux.texture_images is None) [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
self.assertTrue(aux.normals is None) dtype=torch.float32,
self.assertTrue(aux.verts_uvs is None) )
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64)
self.assertTrue(torch.allclose(verts, expected_verts))
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
self.assertTrue(aux.material_colors is None)
self.assertTrue(aux.texture_images is None)
self.assertTrue(aux.normals is None)
self.assertTrue(aux.verts_uvs is None)
def test_load_obj_mtl_no_image(self): def test_load_obj_mtl_no_image(self):
obj_filename = "obj_mtl_no_image/model.obj" obj_filename = "obj_mtl_no_image/model.obj"
@ -825,6 +882,161 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "same type of texture"): with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas]) join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
def test_save_obj_with_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1 3/3 2/2",
"f 1/1 2/2 3/3",
"f 4/4 3/3 2/2",
"f 4/4 2/2 1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r")
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)
def test_save_obj_with_texture_errors(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255)
expected_obj_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"f 1 3 2",
"f 1 2 3",
"f 4 3 2",
"f 4 2 1",
]
)
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
# If only one of verts_uvs/faces_uvs/texture_map is provided
# then textures are not saved
for arg in [
{"verts_uvs": verts_uvs},
{"faces_uvs": faces_uvs},
{"texture_map": texture_map},
]:
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
**arg,
)
# Check there is only 1 file in the temp dir
tempfiles = ["mesh.obj"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(tempfiles, tempfiles_dir)
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
obj_file = StringIO()
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs[..., 2], # Incorrect shape
texture_map=texture_map,
)
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs[..., 0], # Incorrect shape
faces_uvs=faces_uvs,
texture_map=texture_map,
)
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map[..., 1], # Incorrect shape
)
@staticmethod @staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_obj(StringIO(), verts, faces, decimal_places) return lambda: save_obj(StringIO(), verts, faces, decimal_places)