mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
add existing mesh formats to pluggable
Summary: We already have code for obj and ply formats. Here we actually make it available in `IO.load_mesh` and `IO.save_mesh`. Reviewed By: theschnitz, nikhilaravi Differential Revision: D25400650 fbshipit-source-id: f26d6d7fc46c48634a948eea4d255afad13b807b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b183dcb6e8
commit
89532a876e
@@ -5,7 +5,8 @@
|
||||
import os
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -15,6 +16,8 @@ from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
from pytorch3d.renderer import TexturesAtlas, TexturesUV
|
||||
from pytorch3d.structures import Meshes, join_meshes_as_batch
|
||||
|
||||
from .pluggable_formats import MeshFormatInterpreter, endswith
|
||||
|
||||
|
||||
# Faces & Aux type returned from load_obj function.
|
||||
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
|
||||
@@ -286,6 +289,58 @@ def load_objs_as_meshes(
|
||||
return join_meshes_as_batch(mesh_list)
|
||||
|
||||
|
||||
class MeshObjFormat(MeshFormatInterpreter):
|
||||
def __init__(self):
|
||||
self.known_suffixes = (".obj",)
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
include_textures: bool,
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
create_texture_atlas: bool = False,
|
||||
texture_atlas_size: int = 4,
|
||||
texture_wrap: Optional[str] = "repeat",
|
||||
**kwargs,
|
||||
) -> Optional[Meshes]:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
mesh = load_objs_as_meshes(
|
||||
files=[path],
|
||||
device=device,
|
||||
load_textures=include_textures,
|
||||
create_texture_atlas=create_texture_atlas,
|
||||
texture_atlas_size=texture_atlas_size,
|
||||
texture_wrap=texture_wrap,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
return mesh
|
||||
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return False
|
||||
|
||||
verts = data.verts_list()[0]
|
||||
faces = data.faces_list()[0]
|
||||
save_obj(
|
||||
f=path,
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
decimal_places=decimal_places,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _parse_face(
|
||||
line,
|
||||
tokens,
|
||||
|
||||
@@ -10,7 +10,9 @@ from typing import Deque, Optional, Union
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
from .obj_io import MeshObjFormat
|
||||
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
|
||||
from .ply_io import MeshPlyFormat
|
||||
|
||||
|
||||
"""
|
||||
@@ -70,8 +72,8 @@ class IO:
|
||||
self.register_default_formats()
|
||||
|
||||
def register_default_formats(self) -> None:
|
||||
# This will be populated in later diffs
|
||||
pass
|
||||
self.register_meshes_format(MeshObjFormat())
|
||||
self.register_meshes_format(MeshPlyFormat())
|
||||
|
||||
def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
|
||||
"""
|
||||
|
||||
@@ -9,12 +9,16 @@ import sys
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from io import BytesIO
|
||||
from typing import Optional, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
from .pluggable_formats import MeshFormatInterpreter, endswith
|
||||
|
||||
|
||||
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
|
||||
@@ -679,8 +683,7 @@ def load_ply(f, path_manager: Optional[PathManager] = None):
|
||||
# but we don't need to enforce this.
|
||||
|
||||
if not len(face):
|
||||
# pyre-fixme[28]: Unexpected keyword argument `size`.
|
||||
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||
faces = torch.zeros((0, 3), dtype=torch.int64)
|
||||
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
|
||||
if face.shape[1] < 3:
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
@@ -831,3 +834,48 @@ def save_ply(
|
||||
path_manager = PathManager()
|
||||
with _open_file(f, path_manager, "wb") as f:
|
||||
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)
|
||||
|
||||
|
||||
class MeshPlyFormat(MeshFormatInterpreter):
|
||||
def __init__(self):
|
||||
self.known_suffixes = (".ply",)
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
include_textures: bool,
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
**kwargs,
|
||||
) -> Optional[Meshes]:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces = load_ply(f=path, path_manager=path_manager)
|
||||
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
|
||||
return mesh
|
||||
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return False
|
||||
|
||||
# TODO: normals are not saved. We only want to save them if they already exist.
|
||||
verts = data.verts_list()[0]
|
||||
faces = data.faces_list()[0]
|
||||
save_ply(
|
||||
f=path,
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
ascii=binary is False,
|
||||
decimal_places=decimal_places,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user