add existing mesh formats to pluggable

Summary: We already have code for obj and ply formats. Here we actually make it available in `IO.load_mesh` and `IO.save_mesh`.

Reviewed By: theschnitz, nikhilaravi

Differential Revision: D25400650

fbshipit-source-id: f26d6d7fc46c48634a948eea4d255afad13b807b
This commit is contained in:
Jeremy Reizenstein 2021-01-07 15:38:49 -08:00 committed by Facebook GitHub Bot
parent b183dcb6e8
commit 89532a876e
5 changed files with 271 additions and 65 deletions

View File

@ -5,7 +5,8 @@
import os import os
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from typing import List, Optional from pathlib import Path
from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -15,6 +16,8 @@ from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesUV from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch from pytorch3d.structures import Meshes, join_meshes_as_batch
from .pluggable_formats import MeshFormatInterpreter, endswith
# Faces & Aux type returned from load_obj function. # Faces & Aux type returned from load_obj function.
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx") _Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
@ -286,6 +289,58 @@ def load_objs_as_meshes(
return join_meshes_as_batch(mesh_list) return join_meshes_as_batch(mesh_list)
class MeshObjFormat(MeshFormatInterpreter):
def __init__(self):
self.known_suffixes = (".obj",)
def read(
self,
path: Union[str, Path],
include_textures: bool,
device,
path_manager: PathManager,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
**kwargs,
) -> Optional[Meshes]:
if not endswith(path, self.known_suffixes):
return None
mesh = load_objs_as_meshes(
files=[path],
device=device,
load_textures=include_textures,
create_texture_atlas=create_texture_atlas,
texture_atlas_size=texture_atlas_size,
texture_wrap=texture_wrap,
path_manager=path_manager,
)
return mesh
def save(
self,
data: Meshes,
path: Union[str, Path],
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
**kwargs,
) -> bool:
if not endswith(path, self.known_suffixes):
return False
verts = data.verts_list()[0]
faces = data.faces_list()[0]
save_obj(
f=path,
verts=verts,
faces=faces,
decimal_places=decimal_places,
path_manager=path_manager,
)
return True
def _parse_face( def _parse_face(
line, line,
tokens, tokens,

View File

@ -10,7 +10,9 @@ from typing import Deque, Optional, Union
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.structures import Meshes, Pointclouds
from .obj_io import MeshObjFormat
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
from .ply_io import MeshPlyFormat
""" """
@ -70,8 +72,8 @@ class IO:
self.register_default_formats() self.register_default_formats()
def register_default_formats(self) -> None: def register_default_formats(self) -> None:
# This will be populated in later diffs self.register_meshes_format(MeshObjFormat())
pass self.register_meshes_format(MeshPlyFormat())
def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None: def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
""" """

View File

@ -9,12 +9,16 @@ import sys
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from io import BytesIO from io import BytesIO
from typing import Optional, Tuple from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.structures import Meshes
from .pluggable_formats import MeshFormatInterpreter, endswith
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type") _PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
@ -679,8 +683,7 @@ def load_ply(f, path_manager: Optional[PathManager] = None):
# but we don't need to enforce this. # but we don't need to enforce this.
if not len(face): if not len(face):
# pyre-fixme[28]: Unexpected keyword argument `size`. faces = torch.zeros((0, 3), dtype=torch.int64)
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
if face.shape[1] < 3: if face.shape[1] < 3:
raise ValueError("Faces must have at least 3 vertices.") raise ValueError("Faces must have at least 3 vertices.")
@ -831,3 +834,48 @@ def save_ply(
path_manager = PathManager() path_manager = PathManager()
with _open_file(f, path_manager, "wb") as f: with _open_file(f, path_manager, "wb") as f:
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places) _save_ply(f, verts, faces, verts_normals, ascii, decimal_places)
class MeshPlyFormat(MeshFormatInterpreter):
def __init__(self):
self.known_suffixes = (".ply",)
def read(
self,
path: Union[str, Path],
include_textures: bool,
device,
path_manager: PathManager,
**kwargs,
) -> Optional[Meshes]:
if not endswith(path, self.known_suffixes):
return None
verts, faces = load_ply(f=path, path_manager=path_manager)
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
return mesh
def save(
self,
data: Meshes,
path: Union[str, Path],
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
**kwargs,
) -> bool:
if not endswith(path, self.known_suffixes):
return False
# TODO: normals are not saved. We only want to save them if they already exist.
verts = data.verts_list()[0]
faces = data.faces_list()[0]
save_ply(
f=path,
verts=verts,
faces=faces,
ascii=binary is False,
decimal_places=decimal_places,
path_manager=path_manager,
)
return True

View File

@ -5,11 +5,12 @@ import unittest
import warnings import warnings
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import ( from pytorch3d.io.mtl_io import (
_bilinear_interpolation_grid_sample, _bilinear_interpolation_grid_sample,
_bilinear_interpolation_vectorized, _bilinear_interpolation_vectorized,
@ -145,6 +146,70 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(materials is None) self.assertTrue(materials is None)
self.assertTrue(tex_maps is None) self.assertTrue(tex_maps is None)
def test_load_obj_complex_pluggable(self):
"""
This won't work on Windows due to the behavior of NamedTemporaryFile
"""
obj_file = "\n".join(
[
"# this is a comment", # Comments should be ignored.
"v 0.1 0.2 0.3",
"v 0.2 0.3 0.4",
"v 0.3 0.4 0.5",
"v 0.4 0.5 0.6",
"vn 0.000000 0.000000 -1.000000",
"vn -1.000000 -0.000000 -0.000000",
"vn -0.000000 -0.000000 1.000000", # Normals should not be ignored.
"v 0.5 0.6 0.7",
"vt 0.749279 0.501284 0.0", # Some files add 0.0 - ignore this.
"vt 0.999110 0.501077",
"vt 0.999455 0.750380",
"f 1 2 3",
"f 1 2 4 3 5", # Polygons should be split into triangles
"f 2/1/2 3/1/2 4/2/2", # Texture/normals are loaded correctly.
"f -1 -2 1", # Negative indexing counts from the end.
]
)
io = IO()
with NamedTemporaryFile(mode="w", suffix=".obj") as f:
f.write(obj_file)
f.flush()
mesh = io.load_mesh(f.name)
mesh_from_path = io.load_mesh(Path(f.name))
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
f.write(obj_file)
f.flush()
with self.assertRaisesRegex(ValueError, "Invalid file header."):
io.load_mesh(f.name)
expected_verts = torch.tensor(
[
[0.1, 0.2, 0.3],
[0.2, 0.3, 0.4],
[0.3, 0.4, 0.5],
[0.4, 0.5, 0.6],
[0.5, 0.6, 0.7],
],
dtype=torch.float32,
)
expected_faces = torch.tensor(
[
[0, 1, 2], # First face
[0, 1, 3], # Second face (polygon)
[0, 3, 2], # Second face (polygon)
[0, 2, 4], # Second face (polygon)
[1, 2, 3], # Third face (normals / texture)
[4, 3, 0], # Fourth face (negative indices)
],
dtype=torch.int64,
)
self.assertClose(mesh.verts_padded(), expected_verts[None])
self.assertClose(mesh.faces_padded(), expected_faces[None])
self.assertClose(mesh_from_path.verts_padded(), expected_verts[None])
self.assertClose(mesh_from_path.faces_padded(), expected_faces[None])
self.assertIsNone(mesh.textures)
def test_load_obj_normals_only(self): def test_load_obj_normals_only(self):
obj_file = "\n".join( obj_file = "\n".join(
[ [
@ -588,8 +653,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
expected_atlas = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32) expected_atlas = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32)
expected_atlas = expected_atlas[None, None, None, :].expand(2, R, R, -1) expected_atlas = expected_atlas[None, None, None, :].expand(2, R, R, -1)
self.assertTrue(torch.allclose(aux.texture_atlas, expected_atlas)) self.assertTrue(torch.allclose(aux.texture_atlas, expected_atlas))
self.assertEquals(len(aux.material_colors.keys()), 1) self.assertEqual(len(aux.material_colors.keys()), 1)
self.assertEquals(list(aux.material_colors.keys()), ["material_1"]) self.assertEqual(list(aux.material_colors.keys()), ["material_1"])
def test_load_obj_missing_texture(self): def test_load_obj_missing_texture(self):
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = Path(__file__).resolve().parent / "data"

View File

@ -3,12 +3,13 @@
import struct import struct
import unittest import unittest
from io import BytesIO, StringIO from io import BytesIO, StringIO
from tempfile import TemporaryFile from tempfile import NamedTemporaryFile, TemporaryFile
import pytorch3d.io.ply_io import pytorch3d.io.ply_io
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io import IO
from pytorch3d.io.ply_io import load_ply, save_ply from pytorch3d.io.ply_io import load_ply, save_ply
from pytorch3d.utils import torus from pytorch3d.utils import torus
@ -20,6 +21,60 @@ def _load_ply_raw(stream):
return pytorch3d.io.ply_io._load_ply_raw(stream, global_path_manager) return pytorch3d.io.ply_io._load_ply_raw(stream, global_path_manager)
CUBE_PLY_LINES = [
"ply",
"format ascii 1.0",
"comment made by Greg Turk",
"comment this file is a cube",
"element vertex 8",
"property float x",
"property float y",
"property float z",
"element face 6",
"property list uchar int vertex_index",
"end_header",
"0 0 0",
"0 0 1",
"0 1 1",
"0 1 0",
"1 0 0",
"1 0 1",
"1 1 1",
"1 1 0",
"4 0 1 2 3",
"4 7 6 5 4",
"4 0 4 5 1",
"4 1 5 6 2",
"4 2 6 7 3",
"4 3 7 4 0",
]
CUBE_VERTS = [
[0, 0, 0],
[0, 0, 1],
[0, 1, 1],
[0, 1, 0],
[1, 0, 0],
[1, 0, 1],
[1, 1, 1],
[1, 1, 0],
]
CUBE_FACES = [
[0, 1, 2],
[7, 6, 5],
[0, 4, 5],
[1, 5, 6],
[2, 6, 7],
[3, 7, 4],
[0, 2, 3],
[7, 5, 4],
[0, 5, 1],
[1, 6, 2],
[2, 7, 3],
[3, 4, 0],
]
class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
def test_raw_load_simple_ascii(self): def test_raw_load_simple_ascii(self):
ply_file = "\n".join( ply_file = "\n".join(
@ -82,35 +137,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
self.assertClose(x, [4, 5, 1]) self.assertClose(x, [4, 5, 1])
def test_load_simple_ascii(self): def test_load_simple_ascii(self):
ply_file = "\n".join( ply_file = "\n".join(CUBE_PLY_LINES)
[
"ply",
"format ascii 1.0",
"comment made by Greg Turk",
"comment this file is a cube",
"element vertex 8",
"property float x",
"property float y",
"property float z",
"element face 6",
"property list uchar int vertex_index",
"end_header",
"0 0 0",
"0 0 1",
"0 1 1",
"0 1 0",
"1 0 0",
"1 0 1",
"1 1 1",
"1 1 0",
"4 0 1 2 3",
"4 7 6 5 4",
"4 0 4 5 1",
"4 1 5 6 2",
"4 2 6 7 3",
"4 3 7 4 0",
]
)
for line_ending in [None, "\n", "\r\n"]: for line_ending in [None, "\n", "\r\n"]:
if line_ending is None: if line_ending is None:
stream = StringIO(ply_file) stream = StringIO(ply_file)
@ -122,32 +149,41 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
verts, faces = load_ply(stream) verts, faces = load_ply(stream)
self.assertEqual(verts.shape, (8, 3)) self.assertEqual(verts.shape, (8, 3))
self.assertEqual(faces.shape, (12, 3)) self.assertEqual(faces.shape, (12, 3))
verts_expected = [ self.assertClose(verts, torch.FloatTensor(CUBE_VERTS))
[0, 0, 0], self.assertClose(faces, torch.LongTensor(CUBE_FACES))
[0, 0, 1],
[0, 1, 1], def test_pluggable_load_cube(self):
[0, 1, 0], """
[1, 0, 0], This won't work on Windows due to NamedTemporaryFile being reopened.
[1, 0, 1], """
[1, 1, 1], ply_file = "\n".join(CUBE_PLY_LINES)
[1, 1, 0], io = IO()
] with NamedTemporaryFile(mode="w", suffix=".ply") as f:
self.assertClose(verts, torch.FloatTensor(verts_expected)) f.write(ply_file)
faces_expected = [ f.flush()
[0, 1, 2], mesh = io.load_mesh(f.name)
[7, 6, 5], self.assertClose(mesh.verts_padded(), torch.FloatTensor(CUBE_VERTS)[None])
[0, 4, 5], self.assertClose(mesh.faces_padded(), torch.LongTensor(CUBE_FACES)[None])
[1, 5, 6],
[2, 6, 7], device = torch.device("cuda:0")
[3, 7, 4],
[0, 2, 3], with NamedTemporaryFile(mode="w", suffix=".ply") as f2:
[7, 5, 4], io.save_mesh(mesh, f2.name)
[0, 5, 1], f2.flush()
[1, 6, 2], mesh2 = io.load_mesh(f2.name, device=device)
[2, 7, 3], self.assertEqual(mesh2.verts_padded().device, device)
[3, 4, 0], self.assertClose(mesh2.verts_padded().cpu(), mesh.verts_padded())
] self.assertClose(mesh2.faces_padded().cpu(), mesh.faces_padded())
self.assertClose(faces, torch.LongTensor(faces_expected))
with NamedTemporaryFile(mode="w") as f3:
with self.assertRaisesRegex(
ValueError, "No mesh interpreter found to write to"
):
io.save_mesh(mesh, f3.name)
with self.assertRaisesRegex(
ValueError, "No mesh interpreter found to read "
):
io.load_mesh(f3.name)
def test_save_ply_invalid_shapes(self): def test_save_ply_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape