From 6d76336501918e29b9cc15b5b4fe6d0f18818d67 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Tue, 11 Aug 2020 15:50:30 -0700 Subject: [PATCH] Extract more reusable I/O functions Summary: Continue extracting reusable I/O functions to a separate utils module (and remove duplication). Reviewed By: nikhilaravi Differential Revision: D20720433 fbshipit-source-id: e82b19560a5dc8a506c4c4d098da69c202790c4f --- pytorch3d/io/obj_io.py | 27 +++++---------------------- pytorch3d/io/ply_io.py | 20 +++----------------- pytorch3d/io/utils.py | 32 +++++++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 40 deletions(-) diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index 3012526f..a0e7848a 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -10,22 +10,11 @@ from typing import Optional import numpy as np import torch from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas -from pytorch3d.io.utils import _open_file +from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.renderer import TexturesAtlas, TexturesUV from pytorch3d.structures import Meshes, join_meshes_as_batch -def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor: - """ - Return a 2D tensor with the specified cols and dtype filled with data, - even when data is empty. - """ - if not data: - return torch.zeros((0, cols), dtype=dtype, device=device) - - return torch.tensor(data, dtype=dtype, device=device) - - # Faces & Aux type returned from load_obj function. _Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx") _Aux = namedtuple( @@ -57,8 +46,8 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): faces_indices, cols=3, dtype=torch.int64, device=device ) - if pad_value: - mask = faces_indices.eq(pad_value).all(-1) + if pad_value is not None: + mask = faces_indices.eq(pad_value).all(dim=-1) # Change to 0 based indexing. faces_indices[(faces_indices > 0)] -= 1 @@ -66,16 +55,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): # Negative indexing counts from the end. faces_indices[(faces_indices < 0)] += max_index - if pad_value: + if pad_value is not None: faces_indices[mask] = pad_value - # Check indices are valid. - if torch.any(faces_indices >= max_index) or ( - pad_value is None and torch.any(faces_indices < 0) - ): - warnings.warn("Faces have invalid indices") - - return faces_indices + return _check_faces_indices(faces_indices, max_index, pad_value) def load_obj( diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 76709436..54abb7a5 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -12,7 +12,7 @@ from typing import Optional, Tuple import numpy as np import torch -from pytorch3d.io.utils import _open_file +from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file _PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type") @@ -221,17 +221,6 @@ class _PlyHeader: self.elements.append(_PlyElementType(items[1], count)) -def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor: - """ - Return a 2D tensor with the specified cols and dtype filled with data, - even when data is empty. - """ - if not len(data): - return torch.zeros((0, cols), dtype=dtype) - - return torch.tensor(data, dtype=dtype) - - def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType): """ Given an element which has no lists and one type, read the @@ -691,9 +680,7 @@ def load_ply(f): face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]]) faces = _make_tensor(face_list, cols=3, dtype=torch.int64) - if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): - warnings.warn("Faces have invalid indices") - + _check_faces_indices(faces, max_index=verts.shape[0]) return verts, faces @@ -747,8 +734,7 @@ def _save_ply( faces_array = faces.detach().numpy() - if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): - warnings.warn("Faces have invalid indices") + _check_faces_indices(faces, max_index=verts.shape[0]) if len(faces_array): np.savetxt(f, faces_array, "3 %d %d %d") diff --git a/pytorch3d/io/utils.py b/pytorch3d/io/utils.py index 278ced15..b28a864b 100644 --- a/pytorch3d/io/utils.py +++ b/pytorch3d/io/utils.py @@ -2,9 +2,11 @@ import contextlib import pathlib -from typing import IO, ContextManager +import warnings +from typing import IO, ContextManager, Optional import numpy as np +import torch from fvcore.common.file_io import PathManager from PIL import Image @@ -20,6 +22,34 @@ def _open_file(f, mode="r") -> ContextManager[IO]: return contextlib.nullcontext(f) +def _make_tensor( + data, cols: int, dtype: torch.dtype, device: str = "cpu" +) -> torch.Tensor: + """ + Return a 2D tensor with the specified cols and dtype filled with data, + even when data is empty. + """ + if not len(data): + return torch.zeros((0, cols), dtype=dtype, device=device) + + return torch.tensor(data, dtype=dtype, device=device) + + +def _check_faces_indices( + faces_indices: torch.Tensor, max_index: int, pad_value: Optional[int] = None +) -> torch.Tensor: + if pad_value is None: + mask = torch.ones(faces_indices.shape[:-1]).bool() # Keep all faces + else: + # pyre-fixme[16]: `torch.ByteTensor` has no attribute `any` + mask = faces_indices.ne(pad_value).any(dim=-1) + if torch.any(faces_indices[mask] >= max_index) or torch.any( + faces_indices[mask] < 0 + ): + warnings.warn("Faces have invalid indices") + return faces_indices + + def _read_image(file_name: str, format=None): """ Read an image from a file using Pillow.