Extract more reusable I/O functions

Summary: Continue extracting reusable I/O functions to a separate utils module (and remove duplication).

Reviewed By: nikhilaravi

Differential Revision: D20720433

fbshipit-source-id: e82b19560a5dc8a506c4c4d098da69c202790c4f
This commit is contained in:
Patrick Labatut 2020-08-11 15:50:30 -07:00 committed by Facebook GitHub Bot
parent 63ba74f1a8
commit 6d76336501
3 changed files with 39 additions and 40 deletions

View File

@ -10,22 +10,11 @@ from typing import Optional
import numpy as np
import torch
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _open_file
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not data:
return torch.zeros((0, cols), dtype=dtype, device=device)
return torch.tensor(data, dtype=dtype, device=device)
# Faces & Aux type returned from load_obj function.
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
_Aux = namedtuple(
@ -57,8 +46,8 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
faces_indices, cols=3, dtype=torch.int64, device=device
)
if pad_value:
mask = faces_indices.eq(pad_value).all(-1)
if pad_value is not None:
mask = faces_indices.eq(pad_value).all(dim=-1)
# Change to 0 based indexing.
faces_indices[(faces_indices > 0)] -= 1
@ -66,16 +55,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
# Negative indexing counts from the end.
faces_indices[(faces_indices < 0)] += max_index
if pad_value:
if pad_value is not None:
faces_indices[mask] = pad_value
# Check indices are valid.
if torch.any(faces_indices >= max_index) or (
pad_value is None and torch.any(faces_indices < 0)
):
warnings.warn("Faces have invalid indices")
return faces_indices
return _check_faces_indices(faces_indices, max_index, pad_value)
def load_obj(

View File

@ -12,7 +12,7 @@ from typing import Optional, Tuple
import numpy as np
import torch
from pytorch3d.io.utils import _open_file
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
@ -221,17 +221,6 @@ class _PlyHeader:
self.elements.append(_PlyElementType(items[1], count))
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not len(data):
return torch.zeros((0, cols), dtype=dtype)
return torch.tensor(data, dtype=dtype)
def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType):
"""
Given an element which has no lists and one type, read the
@ -691,9 +680,7 @@ def load_ply(f):
face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
faces = _make_tensor(face_list, cols=3, dtype=torch.int64)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")
_check_faces_indices(faces, max_index=verts.shape[0])
return verts, faces
@ -747,8 +734,7 @@ def _save_ply(
faces_array = faces.detach().numpy()
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")
_check_faces_indices(faces, max_index=verts.shape[0])
if len(faces_array):
np.savetxt(f, faces_array, "3 %d %d %d")

View File

@ -2,9 +2,11 @@
import contextlib
import pathlib
from typing import IO, ContextManager
import warnings
from typing import IO, ContextManager, Optional
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from PIL import Image
@ -20,6 +22,34 @@ def _open_file(f, mode="r") -> ContextManager[IO]:
return contextlib.nullcontext(f)
def _make_tensor(
data, cols: int, dtype: torch.dtype, device: str = "cpu"
) -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not len(data):
return torch.zeros((0, cols), dtype=dtype, device=device)
return torch.tensor(data, dtype=dtype, device=device)
def _check_faces_indices(
faces_indices: torch.Tensor, max_index: int, pad_value: Optional[int] = None
) -> torch.Tensor:
if pad_value is None:
mask = torch.ones(faces_indices.shape[:-1]).bool() # Keep all faces
else:
# pyre-fixme[16]: `torch.ByteTensor` has no attribute `any`
mask = faces_indices.ne(pad_value).any(dim=-1)
if torch.any(faces_indices[mask] >= max_index) or torch.any(
faces_indices[mask] < 0
):
warnings.warn("Faces have invalid indices")
return faces_indices
def _read_image(file_name: str, format=None):
"""
Read an image from a file using Pillow.