From 89532a876e77c09edf581f3b7a0de39df761d457 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 7 Jan 2021 15:38:49 -0800 Subject: [PATCH] add existing mesh formats to pluggable Summary: We already have code for obj and ply formats. Here we actually make it available in `IO.load_mesh` and `IO.save_mesh`. Reviewed By: theschnitz, nikhilaravi Differential Revision: D25400650 fbshipit-source-id: f26d6d7fc46c48634a948eea4d255afad13b807b --- pytorch3d/io/obj_io.py | 57 ++++++++++++++- pytorch3d/io/pluggable.py | 6 +- pytorch3d/io/ply_io.py | 54 +++++++++++++- tests/test_obj_io.py | 71 +++++++++++++++++- tests/test_ply_io.py | 148 +++++++++++++++++++++++--------------- 5 files changed, 271 insertions(+), 65 deletions(-) diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index cac0ce25..e6fe9d80 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -5,7 +5,8 @@ import os import warnings from collections import namedtuple -from typing import List, Optional +from pathlib import Path +from typing import List, Optional, Union import numpy as np import torch @@ -15,6 +16,8 @@ from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.renderer import TexturesAtlas, TexturesUV from pytorch3d.structures import Meshes, join_meshes_as_batch +from .pluggable_formats import MeshFormatInterpreter, endswith + # Faces & Aux type returned from load_obj function. _Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx") @@ -286,6 +289,58 @@ def load_objs_as_meshes( return join_meshes_as_batch(mesh_list) +class MeshObjFormat(MeshFormatInterpreter): + def __init__(self): + self.known_suffixes = (".obj",) + + def read( + self, + path: Union[str, Path], + include_textures: bool, + device, + path_manager: PathManager, + create_texture_atlas: bool = False, + texture_atlas_size: int = 4, + texture_wrap: Optional[str] = "repeat", + **kwargs, + ) -> Optional[Meshes]: + if not endswith(path, self.known_suffixes): + return None + mesh = load_objs_as_meshes( + files=[path], + device=device, + load_textures=include_textures, + create_texture_atlas=create_texture_atlas, + texture_atlas_size=texture_atlas_size, + texture_wrap=texture_wrap, + path_manager=path_manager, + ) + return mesh + + def save( + self, + data: Meshes, + path: Union[str, Path], + path_manager: PathManager, + binary: Optional[bool], + decimal_places: Optional[int] = None, + **kwargs, + ) -> bool: + if not endswith(path, self.known_suffixes): + return False + + verts = data.verts_list()[0] + faces = data.faces_list()[0] + save_obj( + f=path, + verts=verts, + faces=faces, + decimal_places=decimal_places, + path_manager=path_manager, + ) + return True + + def _parse_face( line, tokens, diff --git a/pytorch3d/io/pluggable.py b/pytorch3d/io/pluggable.py index 5c03dd4a..b9f0e035 100644 --- a/pytorch3d/io/pluggable.py +++ b/pytorch3d/io/pluggable.py @@ -10,7 +10,9 @@ from typing import Deque, Optional, Union from iopath.common.file_io import PathManager from pytorch3d.structures import Meshes, Pointclouds +from .obj_io import MeshObjFormat from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter +from .ply_io import MeshPlyFormat """ @@ -70,8 +72,8 @@ class IO: self.register_default_formats() def register_default_formats(self) -> None: - # This will be populated in later diffs - pass + self.register_meshes_format(MeshObjFormat()) + self.register_meshes_format(MeshPlyFormat()) def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None: """ diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index f96ca3ef..755e52c5 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -9,12 +9,16 @@ import sys import warnings from collections import namedtuple from io import BytesIO -from typing import Optional, Tuple +from pathlib import Path +from typing import Optional, Tuple, Union import numpy as np import torch from iopath.common.file_io import PathManager from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file +from pytorch3d.structures import Meshes + +from .pluggable_formats import MeshFormatInterpreter, endswith _PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type") @@ -679,8 +683,7 @@ def load_ply(f, path_manager: Optional[PathManager] = None): # but we don't need to enforce this. if not len(face): - # pyre-fixme[28]: Unexpected keyword argument `size`. - faces = torch.zeros(size=(0, 3), dtype=torch.int64) + faces = torch.zeros((0, 3), dtype=torch.int64) elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements if face.shape[1] < 3: raise ValueError("Faces must have at least 3 vertices.") @@ -831,3 +834,48 @@ def save_ply( path_manager = PathManager() with _open_file(f, path_manager, "wb") as f: _save_ply(f, verts, faces, verts_normals, ascii, decimal_places) + + +class MeshPlyFormat(MeshFormatInterpreter): + def __init__(self): + self.known_suffixes = (".ply",) + + def read( + self, + path: Union[str, Path], + include_textures: bool, + device, + path_manager: PathManager, + **kwargs, + ) -> Optional[Meshes]: + if not endswith(path, self.known_suffixes): + return None + + verts, faces = load_ply(f=path, path_manager=path_manager) + mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)]) + return mesh + + def save( + self, + data: Meshes, + path: Union[str, Path], + path_manager: PathManager, + binary: Optional[bool], + decimal_places: Optional[int] = None, + **kwargs, + ) -> bool: + if not endswith(path, self.known_suffixes): + return False + + # TODO: normals are not saved. We only want to save them if they already exist. + verts = data.verts_list()[0] + faces = data.faces_list()[0] + save_ply( + f=path, + verts=verts, + faces=faces, + ascii=binary is False, + decimal_places=decimal_places, + path_manager=path_manager, + ) + return True diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index cccc87fa..6b543c6f 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -5,11 +5,12 @@ import unittest import warnings from io import StringIO from pathlib import Path +from tempfile import NamedTemporaryFile import torch from common_testing import TestCaseMixin from iopath.common.file_io import PathManager -from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj +from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj from pytorch3d.io.mtl_io import ( _bilinear_interpolation_grid_sample, _bilinear_interpolation_vectorized, @@ -145,6 +146,70 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertTrue(materials is None) self.assertTrue(tex_maps is None) + def test_load_obj_complex_pluggable(self): + """ + This won't work on Windows due to the behavior of NamedTemporaryFile + """ + obj_file = "\n".join( + [ + "# this is a comment", # Comments should be ignored. + "v 0.1 0.2 0.3", + "v 0.2 0.3 0.4", + "v 0.3 0.4 0.5", + "v 0.4 0.5 0.6", + "vn 0.000000 0.000000 -1.000000", + "vn -1.000000 -0.000000 -0.000000", + "vn -0.000000 -0.000000 1.000000", # Normals should not be ignored. + "v 0.5 0.6 0.7", + "vt 0.749279 0.501284 0.0", # Some files add 0.0 - ignore this. + "vt 0.999110 0.501077", + "vt 0.999455 0.750380", + "f 1 2 3", + "f 1 2 4 3 5", # Polygons should be split into triangles + "f 2/1/2 3/1/2 4/2/2", # Texture/normals are loaded correctly. + "f -1 -2 1", # Negative indexing counts from the end. + ] + ) + io = IO() + with NamedTemporaryFile(mode="w", suffix=".obj") as f: + f.write(obj_file) + f.flush() + mesh = io.load_mesh(f.name) + mesh_from_path = io.load_mesh(Path(f.name)) + + with NamedTemporaryFile(mode="w", suffix=".ply") as f: + f.write(obj_file) + f.flush() + with self.assertRaisesRegex(ValueError, "Invalid file header."): + io.load_mesh(f.name) + + expected_verts = torch.tensor( + [ + [0.1, 0.2, 0.3], + [0.2, 0.3, 0.4], + [0.3, 0.4, 0.5], + [0.4, 0.5, 0.6], + [0.5, 0.6, 0.7], + ], + dtype=torch.float32, + ) + expected_faces = torch.tensor( + [ + [0, 1, 2], # First face + [0, 1, 3], # Second face (polygon) + [0, 3, 2], # Second face (polygon) + [0, 2, 4], # Second face (polygon) + [1, 2, 3], # Third face (normals / texture) + [4, 3, 0], # Fourth face (negative indices) + ], + dtype=torch.int64, + ) + self.assertClose(mesh.verts_padded(), expected_verts[None]) + self.assertClose(mesh.faces_padded(), expected_faces[None]) + self.assertClose(mesh_from_path.verts_padded(), expected_verts[None]) + self.assertClose(mesh_from_path.faces_padded(), expected_faces[None]) + self.assertIsNone(mesh.textures) + def test_load_obj_normals_only(self): obj_file = "\n".join( [ @@ -588,8 +653,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): expected_atlas = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32) expected_atlas = expected_atlas[None, None, None, :].expand(2, R, R, -1) self.assertTrue(torch.allclose(aux.texture_atlas, expected_atlas)) - self.assertEquals(len(aux.material_colors.keys()), 1) - self.assertEquals(list(aux.material_colors.keys()), ["material_1"]) + self.assertEqual(len(aux.material_colors.keys()), 1) + self.assertEqual(list(aux.material_colors.keys()), ["material_1"]) def test_load_obj_missing_texture(self): DATA_DIR = Path(__file__).resolve().parent / "data" diff --git a/tests/test_ply_io.py b/tests/test_ply_io.py index 48d2dfaf..764ab0bf 100644 --- a/tests/test_ply_io.py +++ b/tests/test_ply_io.py @@ -3,12 +3,13 @@ import struct import unittest from io import BytesIO, StringIO -from tempfile import TemporaryFile +from tempfile import NamedTemporaryFile, TemporaryFile import pytorch3d.io.ply_io import torch from common_testing import TestCaseMixin from iopath.common.file_io import PathManager +from pytorch3d.io import IO from pytorch3d.io.ply_io import load_ply, save_ply from pytorch3d.utils import torus @@ -20,6 +21,60 @@ def _load_ply_raw(stream): return pytorch3d.io.ply_io._load_ply_raw(stream, global_path_manager) +CUBE_PLY_LINES = [ + "ply", + "format ascii 1.0", + "comment made by Greg Turk", + "comment this file is a cube", + "element vertex 8", + "property float x", + "property float y", + "property float z", + "element face 6", + "property list uchar int vertex_index", + "end_header", + "0 0 0", + "0 0 1", + "0 1 1", + "0 1 0", + "1 0 0", + "1 0 1", + "1 1 1", + "1 1 0", + "4 0 1 2 3", + "4 7 6 5 4", + "4 0 4 5 1", + "4 1 5 6 2", + "4 2 6 7 3", + "4 3 7 4 0", +] + +CUBE_VERTS = [ + [0, 0, 0], + [0, 0, 1], + [0, 1, 1], + [0, 1, 0], + [1, 0, 0], + [1, 0, 1], + [1, 1, 1], + [1, 1, 0], +] +CUBE_FACES = [ + [0, 1, 2], + [7, 6, 5], + [0, 4, 5], + [1, 5, 6], + [2, 6, 7], + [3, 7, 4], + [0, 2, 3], + [7, 5, 4], + [0, 5, 1], + [1, 6, 2], + [2, 7, 3], + [3, 4, 0], +] + + class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): def test_raw_load_simple_ascii(self): ply_file = "\n".join( @@ -82,35 +137,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): self.assertClose(x, [4, 5, 1]) def test_load_simple_ascii(self): - ply_file = "\n".join( - [ - "ply", - "format ascii 1.0", - "comment made by Greg Turk", - "comment this file is a cube", - "element vertex 8", - "property float x", - "property float y", - "property float z", - "element face 6", - "property list uchar int vertex_index", - "end_header", - "0 0 0", - "0 0 1", - "0 1 1", - "0 1 0", - "1 0 0", - "1 0 1", - "1 1 1", - "1 1 0", - "4 0 1 2 3", - "4 7 6 5 4", - "4 0 4 5 1", - "4 1 5 6 2", - "4 2 6 7 3", - "4 3 7 4 0", - ] - ) + ply_file = "\n".join(CUBE_PLY_LINES) for line_ending in [None, "\n", "\r\n"]: if line_ending is None: stream = StringIO(ply_file) @@ -122,32 +149,41 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): verts, faces = load_ply(stream) self.assertEqual(verts.shape, (8, 3)) self.assertEqual(faces.shape, (12, 3)) - verts_expected = [ - [0, 0, 0], - [0, 0, 1], - [0, 1, 1], - [0, 1, 0], - [1, 0, 0], - [1, 0, 1], - [1, 1, 1], - [1, 1, 0], - ] - self.assertClose(verts, torch.FloatTensor(verts_expected)) - faces_expected = [ - [0, 1, 2], - [7, 6, 5], - [0, 4, 5], - [1, 5, 6], - [2, 6, 7], - [3, 7, 4], - [0, 2, 3], - [7, 5, 4], - [0, 5, 1], - [1, 6, 2], - [2, 7, 3], - [3, 4, 0], - ] - self.assertClose(faces, torch.LongTensor(faces_expected)) + self.assertClose(verts, torch.FloatTensor(CUBE_VERTS)) + self.assertClose(faces, torch.LongTensor(CUBE_FACES)) + + def test_pluggable_load_cube(self): + """ + This won't work on Windows due to NamedTemporaryFile being reopened. + """ + ply_file = "\n".join(CUBE_PLY_LINES) + io = IO() + with NamedTemporaryFile(mode="w", suffix=".ply") as f: + f.write(ply_file) + f.flush() + mesh = io.load_mesh(f.name) + self.assertClose(mesh.verts_padded(), torch.FloatTensor(CUBE_VERTS)[None]) + self.assertClose(mesh.faces_padded(), torch.LongTensor(CUBE_FACES)[None]) + + device = torch.device("cuda:0") + + with NamedTemporaryFile(mode="w", suffix=".ply") as f2: + io.save_mesh(mesh, f2.name) + f2.flush() + mesh2 = io.load_mesh(f2.name, device=device) + self.assertEqual(mesh2.verts_padded().device, device) + self.assertClose(mesh2.verts_padded().cpu(), mesh.verts_padded()) + self.assertClose(mesh2.faces_padded().cpu(), mesh.faces_padded()) + + with NamedTemporaryFile(mode="w") as f3: + with self.assertRaisesRegex( + ValueError, "No mesh interpreter found to write to" + ): + io.save_mesh(mesh, f3.name) + with self.assertRaisesRegex( + ValueError, "No mesh interpreter found to read " + ): + io.load_mesh(f3.name) def test_save_ply_invalid_shapes(self): # Invalid vertices shape