Extract more reusable I/O functions

Summary: Continue extracting reusable I/O functions to a separate utils module (and remove duplication).

Reviewed By: nikhilaravi

Differential Revision: D20720433

fbshipit-source-id: e82b19560a5dc8a506c4c4d098da69c202790c4f
This commit is contained in:
Patrick Labatut 2020-08-11 15:50:30 -07:00 committed by Facebook GitHub Bot
parent 63ba74f1a8
commit 6d76336501
3 changed files with 39 additions and 40 deletions

View File

@ -10,22 +10,11 @@ from typing import Optional
import numpy as np import numpy as np
import torch import torch
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesUV from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch from pytorch3d.structures import Meshes, join_meshes_as_batch
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not data:
return torch.zeros((0, cols), dtype=dtype, device=device)
return torch.tensor(data, dtype=dtype, device=device)
# Faces & Aux type returned from load_obj function. # Faces & Aux type returned from load_obj function.
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx") _Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
_Aux = namedtuple( _Aux = namedtuple(
@ -57,8 +46,8 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
faces_indices, cols=3, dtype=torch.int64, device=device faces_indices, cols=3, dtype=torch.int64, device=device
) )
if pad_value: if pad_value is not None:
mask = faces_indices.eq(pad_value).all(-1) mask = faces_indices.eq(pad_value).all(dim=-1)
# Change to 0 based indexing. # Change to 0 based indexing.
faces_indices[(faces_indices > 0)] -= 1 faces_indices[(faces_indices > 0)] -= 1
@ -66,16 +55,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
# Negative indexing counts from the end. # Negative indexing counts from the end.
faces_indices[(faces_indices < 0)] += max_index faces_indices[(faces_indices < 0)] += max_index
if pad_value: if pad_value is not None:
faces_indices[mask] = pad_value faces_indices[mask] = pad_value
# Check indices are valid. return _check_faces_indices(faces_indices, max_index, pad_value)
if torch.any(faces_indices >= max_index) or (
pad_value is None and torch.any(faces_indices < 0)
):
warnings.warn("Faces have invalid indices")
return faces_indices
def load_obj( def load_obj(

View File

@ -12,7 +12,7 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from pytorch3d.io.utils import _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type") _PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
@ -221,17 +221,6 @@ class _PlyHeader:
self.elements.append(_PlyElementType(items[1], count)) self.elements.append(_PlyElementType(items[1], count))
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not len(data):
return torch.zeros((0, cols), dtype=dtype)
return torch.tensor(data, dtype=dtype)
def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType): def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType):
""" """
Given an element which has no lists and one type, read the Given an element which has no lists and one type, read the
@ -691,9 +680,7 @@ def load_ply(f):
face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]]) face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
faces = _make_tensor(face_list, cols=3, dtype=torch.int64) faces = _make_tensor(face_list, cols=3, dtype=torch.int64)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): _check_faces_indices(faces, max_index=verts.shape[0])
warnings.warn("Faces have invalid indices")
return verts, faces return verts, faces
@ -747,8 +734,7 @@ def _save_ply(
faces_array = faces.detach().numpy() faces_array = faces.detach().numpy()
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): _check_faces_indices(faces, max_index=verts.shape[0])
warnings.warn("Faces have invalid indices")
if len(faces_array): if len(faces_array):
np.savetxt(f, faces_array, "3 %d %d %d") np.savetxt(f, faces_array, "3 %d %d %d")

View File

@ -2,9 +2,11 @@
import contextlib import contextlib
import pathlib import pathlib
from typing import IO, ContextManager import warnings
from typing import IO, ContextManager, Optional
import numpy as np import numpy as np
import torch
from fvcore.common.file_io import PathManager from fvcore.common.file_io import PathManager
from PIL import Image from PIL import Image
@ -20,6 +22,34 @@ def _open_file(f, mode="r") -> ContextManager[IO]:
return contextlib.nullcontext(f) return contextlib.nullcontext(f)
def _make_tensor(
data, cols: int, dtype: torch.dtype, device: str = "cpu"
) -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not len(data):
return torch.zeros((0, cols), dtype=dtype, device=device)
return torch.tensor(data, dtype=dtype, device=device)
def _check_faces_indices(
faces_indices: torch.Tensor, max_index: int, pad_value: Optional[int] = None
) -> torch.Tensor:
if pad_value is None:
mask = torch.ones(faces_indices.shape[:-1]).bool() # Keep all faces
else:
# pyre-fixme[16]: `torch.ByteTensor` has no attribute `any`
mask = faces_indices.ne(pad_value).any(dim=-1)
if torch.any(faces_indices[mask] >= max_index) or torch.any(
faces_indices[mask] < 0
):
warnings.warn("Faces have invalid indices")
return faces_indices
def _read_image(file_name: str, format=None): def _read_image(file_name: str, format=None):
""" """
Read an image from a file using Pillow. Read an image from a file using Pillow.