From 542e2e7c07fdeef815312b087acfa58094a7aa1e Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Thu, 24 Jun 2021 15:55:09 -0700 Subject: [PATCH] Save UV texture with obj mesh Summary: Add functionality to to save an `.obj` file with associated UV textures: `.png` image and `.mtl` file as well as saving verts_uvs and faces_uvs to the `.obj` file. Reviewed By: bottler Differential Revision: D29337562 fbshipit-source-id: 86829b40dae9224088b328e7f5a16eacf8582eb5 --- pytorch3d/io/obj_io.py | 128 +++++++- tests/common_testing.py | 2 +- tests/test_io_obj.py | 650 ++++++++++++++++++++++++++-------------- 3 files changed, 548 insertions(+), 232 deletions(-) diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index cccc0b00..05df1bfd 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -15,6 +15,7 @@ from typing import List, Optional, Union import numpy as np import torch from iopath.common.file_io import PathManager +from PIL import Image from pytorch3d.common.types import Device from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file @@ -649,42 +650,118 @@ def _load_obj( def save_obj( - f, + f: Union[str, os.PathLike], verts, faces, decimal_places: Optional[int] = None, path_manager: Optional[PathManager] = None, -): + *, + verts_uvs: Optional[torch.Tensor] = None, + faces_uvs: Optional[torch.Tensor] = None, + texture_map: Optional[torch.Tensor] = None, +) -> None: """ Save a mesh to an .obj file. Args: - f: File (or path) to which the mesh should be written. + f: File (str or path) to which the mesh should be written. verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. path_manager: Optional PathManager for interpreting f if it is a str. + verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex. + faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for + each vertex in the face. + texture_map: FloatTensor of shape (H, W, 3) representing the texture map + for the mesh which will be saved as an image. The values are expected + to be in the range [0, 1], """ - if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): - message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." + if len(verts) and (verts.dim() != 2 or verts.size(1) != 3): + message = "'verts' should either be empty or of shape (num_verts, 3)." raise ValueError(message) - if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): - message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." + if len(faces) and (faces.dim() != 2 or faces.size(1) != 3): + message = "'faces' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) + + if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): + message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) + + if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2): + message = "'verts_uvs' should either be empty or of shape (num_verts, 2)." + raise ValueError(message) + + if texture_map is not None and (texture_map.dim() != 3 or texture_map.size(2) != 3): + message = "'texture_map' should either be empty or of shape (H, W, 3)." raise ValueError(message) if path_manager is None: path_manager = PathManager() + save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]]) + output_path = Path(f) + + # Save the .obj file with _open_file(f, path_manager, "w") as f: - return _save(f, verts, faces, decimal_places) + if save_texture: + # Add the header required for the texture info to be loaded correctly + obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem) + f.write(obj_header) + _save( + f, + verts, + faces, + decimal_places, + verts_uvs=verts_uvs, + faces_uvs=faces_uvs, + save_texture=save_texture, + ) + + # Save the .mtl and .png files associated with the texture + if save_texture: + image_path = output_path.with_suffix(".png") + mtl_path = output_path.with_suffix(".mtl") + if isinstance(f, str): + # Back to str for iopath interpretation. + image_path = str(image_path) + mtl_path = str(mtl_path) + + # Save texture map to output folder + # pyre-fixme[16] # undefined attribute cpu + texture_map = texture_map.detach().cpu() * 255.0 + image = Image.fromarray(texture_map.numpy().astype(np.uint8)) + with _open_file(image_path, path_manager, "wb") as im_f: + # pyre-fixme[6] # incompatible parameter type + image.save(im_f) + + # Create .mtl file with the material name and texture map filename + # TODO: enable material properties to also be saved. + with _open_file(mtl_path, path_manager, "w") as f_mtl: + lines = f"newmtl mesh\n" f"map_Kd {output_path.stem}.png\n" + f_mtl.write(lines) # TODO (nikhilar) Speed up this function. -def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None: - assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) - assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) +def _save( + f, + verts, + faces, + decimal_places: Optional[int] = None, + *, + verts_uvs: Optional[torch.Tensor] = None, + faces_uvs: Optional[torch.Tensor] = None, + save_texture: bool = False, +) -> None: + + if len(verts) and (verts.dim() != 2 or verts.size(1) != 3): + message = "'verts' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + + if len(faces) and (faces.dim() != 2 or faces.size(1) != 3): + message = "'faces' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) if not (len(verts) or len(faces)): warnings.warn("Empty 'verts' and 'faces' arguments provided") @@ -705,15 +782,42 @@ def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None: vert = [float_str % verts[i, j] for j in range(D)] lines += "v %s\n" % " ".join(vert) + if save_texture: + if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): + message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) + + if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2): + message = "'verts_uvs' should either be empty or of shape (num_verts, 2)." + raise ValueError(message) + + # pyre-fixme[16] # undefined attribute cpu + verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu() + + # Save verts uvs after verts + if len(verts_uvs): + uV, uD = verts_uvs.shape + for i in range(uV): + uv = [float_str % verts_uvs[i, j] for j in range(uD)] + lines += "vt %s\n" % " ".join(uv) + if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): warnings.warn("Faces have invalid indices") if len(faces): F, P = faces.shape for i in range(F): - face = ["%d" % (faces[i, j] + 1) for j in range(P)] + if save_texture: + # Format faces as {verts_idx}/{verts_uvs_idx} + face = [ + "%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P) + ] + else: + face = ["%d" % (faces[i, j] + 1) for j in range(P)] + if i + 1 < F: lines += "f %s\n" % " ".join(face) + elif i + 1 == F: # No newline at the end of the file. lines += "f %s" % " ".join(face) diff --git a/tests/common_testing.py b/tests/common_testing.py index ea57cb8d..3a73bfeb 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -34,7 +34,7 @@ def get_pytorch3d_dir() -> Path: def load_rgb_image(filename: str, data_dir: Union[str, Path]): - filepath = data_dir / filename + filepath = os.path.join(data_dir, filename) with Image.open(filepath) as raw_image: image = torch.from_numpy(np.array(raw_image) / 255.0) image = image.to(dtype=torch.float32) diff --git a/tests/test_io_obj.py b/tests/test_io_obj.py index 19dcfc77..234526ab 100644 --- a/tests/test_io_obj.py +++ b/tests/test_io_obj.py @@ -7,12 +7,18 @@ import os import unittest import warnings +from collections import Counter from io import StringIO from pathlib import Path -from tempfile import NamedTemporaryFile +from tempfile import NamedTemporaryFile, TemporaryDirectory import torch -from common_testing import TestCaseMixin, get_pytorch3d_dir, get_tests_dir +from common_testing import ( + TestCaseMixin, + get_pytorch3d_dir, + get_tests_dir, + load_rgb_image, +) from iopath.common.file_io import PathManager from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj from pytorch3d.io.mtl_io import ( @@ -42,38 +48,41 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f 1 2 4 3 1", # Polygons should be split into triangles ] ) - obj_file = StringIO(obj_file) - verts, faces, aux = load_obj(obj_file) - normals = aux.normals - textures = aux.verts_uvs - materials = aux.material_colors - tex_maps = aux.texture_images + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - expected_verts = torch.tensor( - [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], - dtype=torch.float32, - ) - expected_faces = torch.tensor( - [ - [0, 1, 2], # First face - [0, 1, 3], # Second face (polygon) - [0, 3, 2], # Second face (polygon) - [0, 2, 0], # Second face (polygon) - ], - dtype=torch.int64, - ) - self.assertTrue(torch.all(verts == expected_verts)) - self.assertTrue(torch.all(faces.verts_idx == expected_faces)) - padded_vals = -(torch.ones_like(faces.verts_idx)) - self.assertTrue(torch.all(faces.normals_idx == padded_vals)) - self.assertTrue(torch.all(faces.textures_idx == padded_vals)) - self.assertTrue( - torch.all(faces.materials_idx == -(torch.ones(len(expected_faces)))) - ) - self.assertTrue(normals is None) - self.assertTrue(textures is None) - self.assertTrue(materials is None) - self.assertTrue(tex_maps is None) + verts, faces, aux = load_obj(Path(f.name)) + normals = aux.normals + textures = aux.verts_uvs + materials = aux.material_colors + tex_maps = aux.texture_images + + expected_verts = torch.tensor( + [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], + dtype=torch.float32, + ) + expected_faces = torch.tensor( + [ + [0, 1, 2], # First face + [0, 1, 3], # Second face (polygon) + [0, 3, 2], # Second face (polygon) + [0, 2, 0], # Second face (polygon) + ], + dtype=torch.int64, + ) + self.assertTrue(torch.all(verts == expected_verts)) + self.assertTrue(torch.all(faces.verts_idx == expected_faces)) + padded_vals = -(torch.ones_like(faces.verts_idx)) + self.assertTrue(torch.all(faces.normals_idx == padded_vals)) + self.assertTrue(torch.all(faces.textures_idx == padded_vals)) + self.assertTrue( + torch.all(faces.materials_idx == -(torch.ones(len(expected_faces)))) + ) + self.assertTrue(normals is None) + self.assertTrue(textures is None) + self.assertTrue(materials is None) + self.assertTrue(tex_maps is None) def test_load_obj_complex(self): obj_file = "\n".join( @@ -96,63 +105,71 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f -1 -2 1", # Negative indexing counts from the end. ] ) - obj_file = StringIO(obj_file) - verts, faces, aux = load_obj(obj_file) - normals = aux.normals - textures = aux.verts_uvs - materials = aux.material_colors - tex_maps = aux.texture_images - expected_verts = torch.tensor( - [ - [0.1, 0.2, 0.3], - [0.2, 0.3, 0.4], - [0.3, 0.4, 0.5], - [0.4, 0.5, 0.6], - [0.5, 0.6, 0.7], - ], - dtype=torch.float32, - ) - expected_faces = torch.tensor( - [ - [0, 1, 2], # First face - [0, 1, 3], # Second face (polygon) - [0, 3, 2], # Second face (polygon) - [0, 2, 4], # Second face (polygon) - [1, 2, 3], # Third face (normals / texture) - [4, 3, 0], # Fourth face (negative indices) - ], - dtype=torch.int64, - ) - expected_normals = torch.tensor( - [ - [0.000000, 0.000000, -1.000000], - [-1.000000, -0.000000, -0.000000], - [-0.000000, -0.000000, 1.000000], - ], - dtype=torch.float32, - ) - expected_textures = torch.tensor( - [[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]], - dtype=torch.float32, - ) - expected_faces_normals_idx = -( - torch.ones_like(expected_faces, dtype=torch.int64) - ) - expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64) - expected_faces_textures_idx = -( - torch.ones_like(expected_faces, dtype=torch.int64) - ) - expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - self.assertTrue(torch.all(verts == expected_verts)) - self.assertTrue(torch.all(faces.verts_idx == expected_faces)) - self.assertClose(normals, expected_normals) - self.assertClose(textures, expected_textures) - self.assertClose(faces.normals_idx, expected_faces_normals_idx) - self.assertClose(faces.textures_idx, expected_faces_textures_idx) - self.assertTrue(materials is None) - self.assertTrue(tex_maps is None) + verts, faces, aux = load_obj(Path(f.name)) + normals = aux.normals + textures = aux.verts_uvs + materials = aux.material_colors + tex_maps = aux.texture_images + + expected_verts = torch.tensor( + [ + [0.1, 0.2, 0.3], + [0.2, 0.3, 0.4], + [0.3, 0.4, 0.5], + [0.4, 0.5, 0.6], + [0.5, 0.6, 0.7], + ], + dtype=torch.float32, + ) + expected_faces = torch.tensor( + [ + [0, 1, 2], # First face + [0, 1, 3], # Second face (polygon) + [0, 3, 2], # Second face (polygon) + [0, 2, 4], # Second face (polygon) + [1, 2, 3], # Third face (normals / texture) + [4, 3, 0], # Fourth face (negative indices) + ], + dtype=torch.int64, + ) + expected_normals = torch.tensor( + [ + [0.000000, 0.000000, -1.000000], + [-1.000000, -0.000000, -0.000000], + [-0.000000, -0.000000, 1.000000], + ], + dtype=torch.float32, + ) + expected_textures = torch.tensor( + [[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]], + dtype=torch.float32, + ) + expected_faces_normals_idx = -( + torch.ones_like(expected_faces, dtype=torch.int64) + ) + expected_faces_normals_idx[4, :] = torch.tensor( + [1, 1, 1], dtype=torch.int64 + ) + expected_faces_textures_idx = -( + torch.ones_like(expected_faces, dtype=torch.int64) + ) + expected_faces_textures_idx[4, :] = torch.tensor( + [0, 0, 1], dtype=torch.int64 + ) + + self.assertTrue(torch.all(verts == expected_verts)) + self.assertTrue(torch.all(faces.verts_idx == expected_faces)) + self.assertClose(normals, expected_normals) + self.assertClose(textures, expected_textures) + self.assertClose(faces.normals_idx, expected_faces_normals_idx) + self.assertClose(faces.textures_idx, expected_faces_textures_idx) + self.assertTrue(materials is None) + self.assertTrue(tex_maps is None) def test_load_obj_complex_pluggable(self): """ @@ -230,7 +247,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f 2//1 3//1 4//2", ] ) - obj_file = StringIO(obj_file) + expected_faces_normals_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64) expected_normals = torch.tensor( [[0.000000, 0.000000, -1.000000], [-1.000000, -0.000000, -0.000000]], @@ -240,19 +257,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], dtype=torch.float32, ) - verts, faces, aux = load_obj(obj_file) - normals = aux.normals - textures = aux.verts_uvs - materials = aux.material_colors - tex_maps = aux.texture_images - self.assertClose(faces.normals_idx, expected_faces_normals_idx) - self.assertClose(normals, expected_normals) - self.assertClose(verts, expected_verts) - # Textures idx padded with -1. - self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1) - self.assertTrue(textures is None) - self.assertTrue(materials is None) - self.assertTrue(tex_maps is None) + + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() + + verts, faces, aux = load_obj(Path(f.name)) + normals = aux.normals + textures = aux.verts_uvs + materials = aux.material_colors + tex_maps = aux.texture_images + self.assertClose(faces.normals_idx, expected_faces_normals_idx) + self.assertClose(normals, expected_normals) + self.assertClose(verts, expected_verts) + # Textures idx padded with -1. + self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1) + self.assertTrue(textures is None) + self.assertTrue(materials is None) + self.assertTrue(tex_maps is None) def test_load_obj_textures_only(self): obj_file = "\n".join( @@ -266,7 +288,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f 2/1 3/1 4/2", ] ) - obj_file = StringIO(obj_file) + expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64) expected_textures = torch.tensor( [[0.999110, 0.501077], [0.999455, 0.750380]], dtype=torch.float32 @@ -275,72 +297,89 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], dtype=torch.float32, ) - verts, faces, aux = load_obj(obj_file) - normals = aux.normals - textures = aux.verts_uvs - materials = aux.material_colors - tex_maps = aux.texture_images - self.assertClose(faces.textures_idx, expected_faces_textures_idx) - self.assertClose(expected_textures, textures) - self.assertClose(expected_verts, verts) - self.assertTrue( - torch.all(faces.normals_idx == -(torch.ones_like(faces.textures_idx))) - ) - self.assertTrue(normals is None) - self.assertTrue(materials is None) - self.assertTrue(tex_maps is None) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() + + verts, faces, aux = load_obj(Path(f.name)) + normals = aux.normals + textures = aux.verts_uvs + materials = aux.material_colors + tex_maps = aux.texture_images + + self.assertClose(faces.textures_idx, expected_faces_textures_idx) + self.assertClose(expected_textures, textures) + self.assertClose(expected_verts, verts) + self.assertTrue( + torch.all(faces.normals_idx == -(torch.ones_like(faces.textures_idx))) + ) + self.assertTrue(normals is None) + self.assertTrue(materials is None) + self.assertTrue(tex_maps is None) def test_load_obj_error_textures(self): obj_file = "\n".join(["vt 0.1"]) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertRaises(ValueError) as err: - load_obj(obj_file) - self.assertTrue("does not have 2 values" in str(err.exception)) + with self.assertRaises(ValueError) as err: + load_obj(Path(f.name)) + self.assertTrue("does not have 2 values" in str(err.exception)) def test_load_obj_error_normals(self): obj_file = "\n".join(["vn 0.1"]) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertRaises(ValueError) as err: - load_obj(obj_file) - self.assertTrue("does not have 3 values" in str(err.exception)) + with self.assertRaises(ValueError) as err: + load_obj(Path(f.name)) + self.assertTrue("does not have 3 values" in str(err.exception)) def test_load_obj_error_vertices(self): obj_file = "\n".join(["v 1"]) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertRaises(ValueError) as err: - load_obj(obj_file) - self.assertTrue("does not have 3 values" in str(err.exception)) + with self.assertRaises(ValueError) as err: + load_obj(Path(f.name)) + self.assertTrue("does not have 3 values" in str(err.exception)) def test_load_obj_error_inconsistent_triplets(self): obj_file = "\n".join(["f 2//1 3/1 4/1/2"]) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertRaises(ValueError) as err: - load_obj(obj_file) - self.assertTrue("Vertex properties are inconsistent" in str(err.exception)) + with self.assertRaises(ValueError) as err: + load_obj(Path(f.name)) + self.assertTrue("Vertex properties are inconsistent" in str(err.exception)) def test_load_obj_error_too_many_vertex_properties(self): obj_file = "\n".join(["f 2/1/1/3"]) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertRaises(ValueError) as err: - load_obj(obj_file) - self.assertTrue( - "Face vertices can only have 3 properties" in str(err.exception) - ) + with self.assertRaises(ValueError) as err: + load_obj(Path(f.name)) + self.assertTrue( + "Face vertices can only have 3 properties" in str(err.exception) + ) def test_load_obj_error_invalid_vertex_indices(self): obj_file = "\n".join( ["v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "v 0.1 0.2 0.3", "f -2 5 1"] ) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): - load_obj(obj_file) + with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): + load_obj(Path(f.name)) def test_load_obj_error_invalid_normal_indices(self): obj_file = "\n".join( @@ -354,10 +393,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f -2/2 2/4 1/1", ] ) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): - load_obj(obj_file) + with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): + load_obj(Path(f.name)) def test_load_obj_error_invalid_texture_indices(self): obj_file = "\n".join( @@ -371,17 +412,20 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f -2//2 2//6 1//1", ] ) - obj_file = StringIO(obj_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() - with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): - load_obj(obj_file) + with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): + load_obj(Path(f.name)) def test_save_obj_invalid_shapes(self): # Invalid vertices shape with self.assertRaises(ValueError) as error: verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) faces = torch.LongTensor([[0, 1, 2]]) - save_obj(StringIO(), verts, faces) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + save_obj(Path(f.name), verts, faces) expected_message = ( "Argument 'verts' should either be empty or of shape (num_verts, 3)." ) @@ -391,7 +435,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError) as error: verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) - save_obj(StringIO(), verts, faces) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + save_obj(Path(f.name), verts, faces) expected_message = ( "Argument 'faces' should either be empty or of shape (num_faces, 3)." ) @@ -402,24 +447,28 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) faces = torch.LongTensor([[0, 1, 2]]) with self.assertWarnsRegex(UserWarning, message_regex): - save_obj(StringIO(), verts, faces) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + save_obj(Path(f.name), verts, faces) faces = torch.LongTensor([[-1, 0, 1]]) with self.assertWarnsRegex(UserWarning, message_regex): - save_obj(StringIO(), verts, faces) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + save_obj(Path(f.name), verts, faces) def _test_save_load(self, verts, faces): - f = StringIO() - save_obj(f, verts, faces) - f.seek(0) - expected_verts, expected_faces = verts, faces - if not len(expected_verts): # Always compare with a (V, 3) tensor - expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32) - if not len(expected_faces): # Always compare with an (F, 3) tensor - expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64) - actual_verts, actual_faces, _ = load_obj(f) - self.assertClose(expected_verts, actual_verts) - self.assertClose(expected_faces, actual_faces.verts_idx) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + file_path = Path(f.name) + save_obj(file_path, verts, faces) + f.flush() + + expected_verts, expected_faces = verts, faces + if not len(expected_verts): # Always compare with a (V, 3) tensor + expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32) + if not len(expected_faces): # Always compare with an (F, 3) tensor + expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64) + actual_verts, actual_faces, _ = load_obj(file_path) + self.assertClose(expected_verts, actual_verts) + self.assertClose(expected_faces, actual_faces.verts_idx) def test_empty_save_load_obj(self): # Vertices + empty faces @@ -467,22 +516,23 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): faces = torch.tensor( [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 ) - obj_file = StringIO() - save_obj(obj_file, verts, faces, decimal_places=2) - expected_file = "\n".join( - [ - "v 0.01 0.20 0.30", - "v 0.20 0.03 0.41", - "v 0.30 0.40 0.05", - "v 0.60 0.70 0.80", - "f 1 3 2", - "f 1 2 3", - "f 4 3 2", - "f 4 2 1", - ] - ) - actual_file = obj_file.getvalue() - self.assertEqual(actual_file, expected_file) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + save_obj(Path(f.name), verts, faces, decimal_places=2) + + expected_file = "\n".join( + [ + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "f 1 3 2", + "f 1 2 3", + "f 4 3 2", + "f 4 2 1", + ] + ) + actual_file = open(Path(f.name), "r") + self.assertEqual(actual_file.read(), expected_file) def test_load_mtl(self): obj_filename = "cow_mesh/cow.obj" @@ -534,36 +584,39 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "Ns 10.0", ] ) - mtl_file = StringIO(mtl_file) - material_properties, texture_files = _parse_mtl( - mtl_file, path_manager=PathManager(), device="cpu" - ) + with NamedTemporaryFile(mode="w", suffix=".mtl") as f: + f.write(mtl_file) + f.flush() - dtype = torch.float32 - expected_materials = { - "material_1": { - "ambient_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype), - "diffuse_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype), - "specular_color": torch.tensor([0.0, 0.0, 0.0], dtype=dtype), - "shininess": torch.tensor([10.0], dtype=dtype), + material_properties, texture_files = _parse_mtl( + Path(f.name), path_manager=PathManager(), device="cpu" + ) + + dtype = torch.float32 + expected_materials = { + "material_1": { + "ambient_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype), + "diffuse_color": torch.tensor([1.0, 1.0, 1.0], dtype=dtype), + "specular_color": torch.tensor([0.0, 0.0, 0.0], dtype=dtype), + "shininess": torch.tensor([10.0], dtype=dtype), + } } - } - # Check that there is a material with name material_1 - self.assertTrue(tuple(texture_files.keys()) == ("material_1",)) - # Check that there is an image with name material 1.png - self.assertTrue(texture_files["material_1"] == "material 1.png") + # Check that there is a material with name material_1 + self.assertTrue(tuple(texture_files.keys()) == ("material_1",)) + # Check that there is an image with name material 1.png + self.assertTrue(texture_files["material_1"] == "material 1.png") - # Check all keys and values in dictionary are the same. - for n1, n2 in zip(material_properties.keys(), expected_materials.keys()): - self.assertTrue(n1 == n2) - for k1, k2 in zip( - material_properties[n1].keys(), expected_materials[n2].keys() - ): - self.assertTrue( - torch.allclose( - material_properties[n1][k1], expected_materials[n2][k2] + # Check all keys and values in dictionary are the same. + for n1, n2 in zip(material_properties.keys(), expected_materials.keys()): + self.assertTrue(n1 == n2) + for k1, k2 in zip( + material_properties[n1].keys(), expected_materials[n2].keys() + ): + self.assertTrue( + torch.allclose( + material_properties[n1][k1], expected_materials[n2][k2] + ) ) - ) def test_load_mtl_texture_atlas_compare_softras(self): # Load saved texture atlas created with SoftRas. @@ -618,21 +671,25 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f 1 2 4", ] ) - obj_file = StringIO(obj_file) - with self.assertWarnsRegex(UserWarning, "No mtl file provided"): - verts, faces, aux = load_obj(obj_file) - expected_verts = torch.tensor( - [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], - dtype=torch.float32, - ) - expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64) - self.assertTrue(torch.allclose(verts, expected_verts)) - self.assertTrue(torch.allclose(faces.verts_idx, expected_faces)) - self.assertTrue(aux.material_colors is None) - self.assertTrue(aux.texture_images is None) - self.assertTrue(aux.normals is None) - self.assertTrue(aux.verts_uvs is None) + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() + + with self.assertWarnsRegex(UserWarning, "No mtl file provided"): + verts, faces, aux = load_obj(Path(f.name)) + + expected_verts = torch.tensor( + [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]], + dtype=torch.float32, + ) + expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64) + self.assertTrue(torch.allclose(verts, expected_verts)) + self.assertTrue(torch.allclose(faces.verts_idx, expected_faces)) + self.assertTrue(aux.material_colors is None) + self.assertTrue(aux.texture_images is None) + self.assertTrue(aux.normals is None) + self.assertTrue(aux.verts_uvs is None) def test_load_obj_mtl_no_image(self): obj_filename = "obj_mtl_no_image/model.obj" @@ -825,6 +882,161 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): with self.assertRaisesRegex(ValueError, "same type of texture"): join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas]) + def test_save_obj_with_texture(self): + verts = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + faces = torch.tensor( + [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 + ) + verts_uvs = torch.tensor( + [[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]], + dtype=torch.float32, + ) + faces_uvs = faces + texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0 + + with TemporaryDirectory() as temp_dir: + obj_file = os.path.join(temp_dir, "mesh.obj") + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + verts_uvs=verts_uvs, + faces_uvs=faces_uvs, + texture_map=texture_map, + ) + + expected_obj_file = "\n".join( + [ + "", + "mtllib mesh.mtl", + "usemtl mesh", + "", + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "vt 0.02 0.50", + "vt 0.30 0.03", + "vt 0.32 0.12", + "vt 0.36 0.17", + "f 1/1 3/3 2/2", + "f 1/1 2/2 3/3", + "f 4/4 3/3 2/2", + "f 4/4 2/2 1/1", + ] + ) + expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""]) + + # Check there are only 3 files in the temp dir + tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"] + tempfiles_dir = os.listdir(temp_dir) + self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir)) + + # Check the obj file is saved correctly + actual_file = open(obj_file, "r") + self.assertEqual(actual_file.read(), expected_obj_file) + + # Check the mtl file is saved correctly + mtl_file_name = os.path.join(temp_dir, "mesh.mtl") + mtl_file = open(mtl_file_name, "r") + self.assertEqual(mtl_file.read(), expected_mtl_file) + + # Check the texture image file is saved correctly + texture_image = load_rgb_image("mesh.png", temp_dir) + self.assertClose(texture_image, texture_map) + + def test_save_obj_with_texture_errors(self): + verts = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + faces = torch.tensor( + [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 + ) + verts_uvs = torch.tensor( + [[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]], + dtype=torch.float32, + ) + faces_uvs = faces + texture_map = torch.randint(size=(2, 2, 3), high=255) + + expected_obj_file = "\n".join( + [ + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "f 1 3 2", + "f 1 2 3", + "f 4 3 2", + "f 4 2 1", + ] + ) + with TemporaryDirectory() as temp_dir: + obj_file = os.path.join(temp_dir, "mesh.obj") + + # If only one of verts_uvs/faces_uvs/texture_map is provided + # then textures are not saved + for arg in [ + {"verts_uvs": verts_uvs}, + {"faces_uvs": faces_uvs}, + {"texture_map": texture_map}, + ]: + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + **arg, + ) + + # Check there is only 1 file in the temp dir + tempfiles = ["mesh.obj"] + tempfiles_dir = os.listdir(temp_dir) + self.assertEqual(tempfiles, tempfiles_dir) + + # Check the obj file is saved correctly + actual_file = open(obj_file, "r") + self.assertEqual(actual_file.read(), expected_obj_file) + + obj_file = StringIO() + with self.assertRaises(ValueError): + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + verts_uvs=verts_uvs, + faces_uvs=faces_uvs[..., 2], # Incorrect shape + texture_map=texture_map, + ) + + with self.assertRaises(ValueError): + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + verts_uvs=verts_uvs[..., 0], # Incorrect shape + faces_uvs=faces_uvs, + texture_map=texture_map, + ) + + with self.assertRaises(ValueError): + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + verts_uvs=verts_uvs, + faces_uvs=faces_uvs, + texture_map=texture_map[..., 1], # Incorrect shape + ) + @staticmethod def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): return lambda: save_obj(StringIO(), verts, faces, decimal_places)