Save UV texture with obj mesh

Summary: Add functionality to to save an `.obj` file with associated UV textures: `.png` image and `.mtl` file as well as saving verts_uvs and faces_uvs to the `.obj` file.

Reviewed By: bottler

Differential Revision: D29337562

fbshipit-source-id: 86829b40dae9224088b328e7f5a16eacf8582eb5
This commit is contained in:
Nikhila Ravi 2021-06-24 15:55:09 -07:00 committed by Facebook GitHub Bot
parent 64289a491d
commit 542e2e7c07
3 changed files with 548 additions and 232 deletions

View File

@ -15,6 +15,7 @@ from typing import List, Optional, Union
import numpy as np
import torch
from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.common.types import Device
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
@ -649,42 +650,118 @@ def _load_obj(
def save_obj(
f,
f: Union[str, os.PathLike],
verts,
faces,
decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
):
*,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
texture_map: Optional[torch.Tensor] = None,
) -> None:
"""
Save a mesh to an .obj file.
Args:
f: File (or path) to which the mesh should be written.
f: File (str or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if
it is a str.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
each vertex in the face.
texture_map: FloatTensor of shape (H, W, 3) representing the texture map
for the mesh which will be saved as an image. The values are expected
to be in the range [0, 1],
"""
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)
if texture_map is not None and (texture_map.dim() != 3 or texture_map.size(2) != 3):
message = "'texture_map' should either be empty or of shape (H, W, 3)."
raise ValueError(message)
if path_manager is None:
path_manager = PathManager()
save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]])
output_path = Path(f)
# Save the .obj file
with _open_file(f, path_manager, "w") as f:
return _save(f, verts, faces, decimal_places)
if save_texture:
# Add the header required for the texture info to be loaded correctly
obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem)
f.write(obj_header)
_save(
f,
verts,
faces,
decimal_places,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
save_texture=save_texture,
)
# Save the .mtl and .png files associated with the texture
if save_texture:
image_path = output_path.with_suffix(".png")
mtl_path = output_path.with_suffix(".mtl")
if isinstance(f, str):
# Back to str for iopath interpretation.
image_path = str(image_path)
mtl_path = str(mtl_path)
# Save texture map to output folder
# pyre-fixme[16] # undefined attribute cpu
texture_map = texture_map.detach().cpu() * 255.0
image = Image.fromarray(texture_map.numpy().astype(np.uint8))
with _open_file(image_path, path_manager, "wb") as im_f:
# pyre-fixme[6] # incompatible parameter type
image.save(im_f)
# Create .mtl file with the material name and texture map filename
# TODO: enable material properties to also be saved.
with _open_file(mtl_path, path_manager, "w") as f_mtl:
lines = f"newmtl mesh\n" f"map_Kd {output_path.stem}.png\n"
f_mtl.write(lines)
# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
def _save(
f,
verts,
faces,
decimal_places: Optional[int] = None,
*,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
save_texture: bool = False,
) -> None:
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
@ -705,15 +782,42 @@ def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
if save_texture:
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)
# pyre-fixme[16] # undefined attribute cpu
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
# Save verts uvs after verts
if len(verts_uvs):
uV, uD = verts_uvs.shape
for i in range(uV):
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
lines += "vt %s\n" % " ".join(uv)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")
if len(faces):
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if save_texture:
# Format faces as {verts_idx}/{verts_uvs_idx}
face = [
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
]
else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)

View File

@ -34,7 +34,7 @@ def get_pytorch3d_dir() -> Path:
def load_rgb_image(filename: str, data_dir: Union[str, Path]):
filepath = data_dir / filename
filepath = os.path.join(data_dir, filename)
with Image.open(filepath) as raw_image:
image = torch.from_numpy(np.array(raw_image) / 255.0)
image = image.to(dtype=torch.float32)

View File

@ -7,12 +7,18 @@
import os
import unittest
import warnings
from collections import Counter
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, TemporaryDirectory
import torch
from common_testing import TestCaseMixin, get_pytorch3d_dir, get_tests_dir
from common_testing import (
TestCaseMixin,
get_pytorch3d_dir,
get_tests_dir,
load_rgb_image,
)
from iopath.common.file_io import PathManager
from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import (
@ -42,38 +48,41 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 1 2 4 3 1", # Polygons should be split into triangles
]
)
obj_file = StringIO(obj_file)
verts, faces, aux = load_obj(obj_file)
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
expected_verts = torch.tensor(
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 0], # Second face (polygon)
],
dtype=torch.int64,
)
self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
padded_vals = -(torch.ones_like(faces.verts_idx))
self.assertTrue(torch.all(faces.normals_idx == padded_vals))
self.assertTrue(torch.all(faces.textures_idx == padded_vals))
self.assertTrue(
torch.all(faces.materials_idx == -(torch.ones(len(expected_faces))))
)
self.assertTrue(normals is None)
self.assertTrue(textures is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
expected_verts = torch.tensor(
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 0], # Second face (polygon)
],
dtype=torch.int64,
)
self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
padded_vals = -(torch.ones_like(faces.verts_idx))
self.assertTrue(torch.all(faces.normals_idx == padded_vals))
self.assertTrue(torch.all(faces.textures_idx == padded_vals))
self.assertTrue(
torch.all(faces.materials_idx == -(torch.ones(len(expected_faces))))
)
self.assertTrue(normals is None)
self.assertTrue(textures is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_complex(self):
obj_file = "\n".join(
@ -96,63 +105,71 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -1 -2 1", # Negative indexing counts from the end.
]
)
obj_file = StringIO(obj_file)
verts, faces, aux = load_obj(obj_file)
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
expected_verts = torch.tensor(
[
[0.1, 0.2, 0.3],
[0.2, 0.3, 0.4],
[0.3, 0.4, 0.5],
[0.4, 0.5, 0.6],
[0.5, 0.6, 0.7],
],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 4], # Second face (polygon)
[1, 2, 3], # Third face (normals / texture)
[4, 3, 0], # Fourth face (negative indices)
],
dtype=torch.int64,
)
expected_normals = torch.tensor(
[
[0.000000, 0.000000, -1.000000],
[-1.000000, -0.000000, -0.000000],
[-0.000000, -0.000000, 1.000000],
],
dtype=torch.float32,
)
expected_textures = torch.tensor(
[[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
dtype=torch.float32,
)
expected_faces_normals_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64)
expected_faces_textures_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
self.assertClose(normals, expected_normals)
self.assertClose(textures, expected_textures)
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
expected_verts = torch.tensor(
[
[0.1, 0.2, 0.3],
[0.2, 0.3, 0.4],
[0.3, 0.4, 0.5],
[0.4, 0.5, 0.6],
[0.5, 0.6, 0.7],
],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 4], # Second face (polygon)
[1, 2, 3], # Third face (normals / texture)
[4, 3, 0], # Fourth face (negative indices)
],
dtype=torch.int64,
)
expected_normals = torch.tensor(
[
[0.000000, 0.000000, -1.000000],
[-1.000000, -0.000000, -0.000000],
[-0.000000, -0.000000, 1.000000],
],
dtype=torch.float32,
)
expected_textures = torch.tensor(
[[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
dtype=torch.float32,
)
expected_faces_normals_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_normals_idx[4, :] = torch.tensor(
[1, 1, 1], dtype=torch.int64
)
expected_faces_textures_idx = -(
torch.ones_like(expected_faces, dtype=torch.int64)
)
expected_faces_textures_idx[4, :] = torch.tensor(
[0, 0, 1], dtype=torch.int64
)
self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
self.assertClose(normals, expected_normals)
self.assertClose(textures, expected_textures)
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_complex_pluggable(self):
"""
@ -230,7 +247,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 2//1 3//1 4//2",
]
)
obj_file = StringIO(obj_file)
expected_faces_normals_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
expected_normals = torch.tensor(
[[0.000000, 0.000000, -1.000000], [-1.000000, -0.000000, -0.000000]],
@ -240,19 +257,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32,
)
verts, faces, aux = load_obj(obj_file)
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
self.assertClose(normals, expected_normals)
self.assertClose(verts, expected_verts)
# Textures idx padded with -1.
self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1)
self.assertTrue(textures is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
self.assertClose(normals, expected_normals)
self.assertClose(verts, expected_verts)
# Textures idx padded with -1.
self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1)
self.assertTrue(textures is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_textures_only(self):
obj_file = "\n".join(
@ -266,7 +288,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 2/1 3/1 4/2",
]
)
obj_file = StringIO(obj_file)
expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
expected_textures = torch.tensor(
[[0.999110, 0.501077], [0.999455, 0.750380]], dtype=torch.float32
@ -275,72 +297,89 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32,
)
verts, faces, aux = load_obj(obj_file)
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertClose(expected_textures, textures)
self.assertClose(expected_verts, verts)
self.assertTrue(
torch.all(faces.normals_idx == -(torch.ones_like(faces.textures_idx)))
)
self.assertTrue(normals is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
verts, faces, aux = load_obj(Path(f.name))
normals = aux.normals
textures = aux.verts_uvs
materials = aux.material_colors
tex_maps = aux.texture_images
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertClose(expected_textures, textures)
self.assertClose(expected_verts, verts)
self.assertTrue(
torch.all(faces.normals_idx == -(torch.ones_like(faces.textures_idx)))
)
self.assertTrue(normals is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
def test_load_obj_error_textures(self):
obj_file = "\n".join(["vt 0.1"])
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err:
load_obj(obj_file)
self.assertTrue("does not have 2 values" in str(err.exception))
with self.assertRaises(ValueError) as err:
load_obj(Path(f.name))
self.assertTrue("does not have 2 values" in str(err.exception))
def test_load_obj_error_normals(self):
obj_file = "\n".join(["vn 0.1"])
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err:
load_obj(obj_file)
self.assertTrue("does not have 3 values" in str(err.exception))
with self.assertRaises(ValueError) as err:
load_obj(Path(f.name))
self.assertTrue("does not have 3 values" in str(err.exception))
def test_load_obj_error_vertices(self):
obj_file = "\n".join(["v 1"])
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err:
load_obj(obj_file)
self.assertTrue("does not have 3 values" in str(err.exception))
with self.assertRaises(ValueError) as err:
load_obj(Path(f.name))
self.assertTrue("does not have 3 values" in str(err.exception))
def test_load_obj_error_inconsistent_triplets(self):
obj_file = "\n".join(["f 2//1 3/1 4/1/2"])
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err:
load_obj(obj_file)
self.assertTrue("Vertex properties are inconsistent" in str(err.exception))
with self.assertRaises(ValueError) as err:
load_obj(Path(f.name))
self.assertTrue("Vertex properties are inconsistent" in str(err.exception))
def test_load_obj_error_too_many_vertex_properties(self):
obj_file = "\n".join(["f 2/1/1/3"])
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertRaises(ValueError) as err:
load_obj(obj_file)
self.assertTrue(
"Face vertices can only have 3 properties" in str(err.exception)
)
with self.assertRaises(ValueError) as err:
load_obj(Path(f.name))
self.assertTrue(
"Face vertices can only have 3 properties" in str(err.exception)
)
def test_load_obj_error_invalid_vertex_indices(self):
obj_file = "\n".join(
["v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "f -2 5 1"]
)
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file)
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(Path(f.name))
def test_load_obj_error_invalid_normal_indices(self):
obj_file = "\n".join(
@ -354,10 +393,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -2/2 2/4 1/1",
]
)
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file)
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(Path(f.name))
def test_load_obj_error_invalid_texture_indices(self):
obj_file = "\n".join(
@ -371,17 +412,20 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f -2//2 2//6 1//1",
]
)
obj_file = StringIO(obj_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file)
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(Path(f.name))
def test_save_obj_invalid_shapes(self):
# Invalid vertices shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]])
save_obj(StringIO(), verts, faces)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
@ -391,7 +435,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_obj(StringIO(), verts, faces)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
@ -402,24 +447,28 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces)
def _test_save_load(self, verts, faces):
f = StringIO()
save_obj(f, verts, faces)
f.seek(0)
expected_verts, expected_faces = verts, faces
if not len(expected_verts): # Always compare with a (V, 3) tensor
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
if not len(expected_faces): # Always compare with an (F, 3) tensor
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
actual_verts, actual_faces, _ = load_obj(f)
self.assertClose(expected_verts, actual_verts)
self.assertClose(expected_faces, actual_faces.verts_idx)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
file_path = Path(f.name)
save_obj(file_path, verts, faces)
f.flush()
expected_verts, expected_faces = verts, faces
if not len(expected_verts): # Always compare with a (V, 3) tensor
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
if not len(expected_faces): # Always compare with an (F, 3) tensor
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
actual_verts, actual_faces, _ = load_obj(file_path)
self.assertClose(expected_verts, actual_verts)
self.assertClose(expected_faces, actual_faces.verts_idx)
def test_empty_save_load_obj(self):
# Vertices + empty faces
@ -467,22 +516,23 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
obj_file = StringIO()
save_obj(obj_file, verts, faces, decimal_places=2)
expected_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"f 1 3 2",
"f 1 2 3",
"f 4 3 2",
"f 4 2 1",
]
)
actual_file = obj_file.getvalue()
self.assertEqual(actual_file, expected_file)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
save_obj(Path(f.name), verts, faces, decimal_places=2)
expected_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"f 1 3 2",
"f 1 2 3",
"f 4 3 2",
"f 4 2 1",
]
)
actual_file = open(Path(f.name), "r")
self.assertEqual(actual_file.read(), expected_file)
def test_load_mtl(self):
obj_filename = "cow_mesh/cow.obj"
@ -534,36 +584,39 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"Ns 10.0",
]
)
mtl_file = StringIO(mtl_file)
material_properties, texture_files = _parse_mtl(
mtl_file, path_manager=PathManager(), device="cpu"
)
with NamedTemporaryFile(mode="w", suffix=".mtl") as f:
f.write(mtl_file)
f.flush()
dtype = torch.float32
expected_materials = {
"material_1": {
"ambient_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype),
"diffuse_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype),
"specular_color": torch.tensor([0.0, 0.0, 0.0], dtype=dtype),
"shininess": torch.tensor([10.0], dtype=dtype),
material_properties, texture_files = _parse_mtl(
Path(f.name), path_manager=PathManager(), device="cpu"
)
dtype = torch.float32
expected_materials = {
"material_1": {
"ambient_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype),
"diffuse_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype),
"specular_color": torch.tensor([0.0, 0.0, 0.0], dtype=dtype),
"shininess": torch.tensor([10.0], dtype=dtype),
}
}
}
# Check that there is a material with name material_1
self.assertTrue(tuple(texture_files.keys()) == ("material_1",))
# Check that there is an image with name material 1.png
self.assertTrue(texture_files["material_1"] == "material 1.png")
# Check that there is a material with name material_1
self.assertTrue(tuple(texture_files.keys()) == ("material_1",))
# Check that there is an image with name material 1.png
self.assertTrue(texture_files["material_1"] == "material 1.png")
# Check all keys and values in dictionary are the same.
for n1, n2 in zip(material_properties.keys(), expected_materials.keys()):
self.assertTrue(n1 == n2)
for k1, k2 in zip(
material_properties[n1].keys(), expected_materials[n2].keys()
):
self.assertTrue(
torch.allclose(
material_properties[n1][k1], expected_materials[n2][k2]
# Check all keys and values in dictionary are the same.
for n1, n2 in zip(material_properties.keys(), expected_materials.keys()):
self.assertTrue(n1 == n2)
for k1, k2 in zip(
material_properties[n1].keys(), expected_materials[n2].keys()
):
self.assertTrue(
torch.allclose(
material_properties[n1][k1], expected_materials[n2][k2]
)
)
)
def test_load_mtl_texture_atlas_compare_softras(self):
# Load saved texture atlas created with SoftRas.
@ -618,21 +671,25 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 1 2 4",
]
)
obj_file = StringIO(obj_file)
with self.assertWarnsRegex(UserWarning, "No mtl file provided"):
verts, faces, aux = load_obj(obj_file)
expected_verts = torch.tensor(
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32,
)
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64)
self.assertTrue(torch.allclose(verts, expected_verts))
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
self.assertTrue(aux.material_colors is None)
self.assertTrue(aux.texture_images is None)
self.assertTrue(aux.normals is None)
self.assertTrue(aux.verts_uvs is None)
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
with self.assertWarnsRegex(UserWarning, "No mtl file provided"):
verts, faces, aux = load_obj(Path(f.name))
expected_verts = torch.tensor(
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
dtype=torch.float32,
)
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64)
self.assertTrue(torch.allclose(verts, expected_verts))
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
self.assertTrue(aux.material_colors is None)
self.assertTrue(aux.texture_images is None)
self.assertTrue(aux.normals is None)
self.assertTrue(aux.verts_uvs is None)
def test_load_obj_mtl_no_image(self):
obj_filename = "obj_mtl_no_image/model.obj"
@ -825,6 +882,161 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
def test_save_obj_with_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1 3/3 2/2",
"f 1/1 2/2 3/3",
"f 4/4 3/3 2/2",
"f 4/4 2/2 1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r")
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)
def test_save_obj_with_texture_errors(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255)
expected_obj_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"f 1 3 2",
"f 1 2 3",
"f 4 3 2",
"f 4 2 1",
]
)
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
# If only one of verts_uvs/faces_uvs/texture_map is provided
# then textures are not saved
for arg in [
{"verts_uvs": verts_uvs},
{"faces_uvs": faces_uvs},
{"texture_map": texture_map},
]:
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
**arg,
)
# Check there is only 1 file in the temp dir
tempfiles = ["mesh.obj"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(tempfiles, tempfiles_dir)
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
obj_file = StringIO()
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs[..., 2], # Incorrect shape
texture_map=texture_map,
)
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs[..., 0], # Incorrect shape
faces_uvs=faces_uvs,
texture_map=texture_map,
)
with self.assertRaises(ValueError):
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map[..., 1], # Incorrect shape
)
@staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_obj(StringIO(), verts, faces, decimal_places)