mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Extract more reusable I/O functions
Summary: Continue extracting reusable I/O functions to a separate utils module (and remove duplication). Reviewed By: nikhilaravi Differential Revision: D20720433 fbshipit-source-id: e82b19560a5dc8a506c4c4d098da69c202790c4f
This commit is contained in:
parent
63ba74f1a8
commit
6d76336501
@ -10,22 +10,11 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
|
||||
from pytorch3d.io.utils import _open_file
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
from pytorch3d.renderer import TexturesAtlas, TexturesUV
|
||||
from pytorch3d.structures import Meshes, join_meshes_as_batch
|
||||
|
||||
|
||||
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
|
||||
"""
|
||||
Return a 2D tensor with the specified cols and dtype filled with data,
|
||||
even when data is empty.
|
||||
"""
|
||||
if not data:
|
||||
return torch.zeros((0, cols), dtype=dtype, device=device)
|
||||
|
||||
return torch.tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
# Faces & Aux type returned from load_obj function.
|
||||
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
|
||||
_Aux = namedtuple(
|
||||
@ -57,8 +46,8 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
|
||||
faces_indices, cols=3, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
if pad_value:
|
||||
mask = faces_indices.eq(pad_value).all(-1)
|
||||
if pad_value is not None:
|
||||
mask = faces_indices.eq(pad_value).all(dim=-1)
|
||||
|
||||
# Change to 0 based indexing.
|
||||
faces_indices[(faces_indices > 0)] -= 1
|
||||
@ -66,16 +55,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
|
||||
# Negative indexing counts from the end.
|
||||
faces_indices[(faces_indices < 0)] += max_index
|
||||
|
||||
if pad_value:
|
||||
if pad_value is not None:
|
||||
faces_indices[mask] = pad_value
|
||||
|
||||
# Check indices are valid.
|
||||
if torch.any(faces_indices >= max_index) or (
|
||||
pad_value is None and torch.any(faces_indices < 0)
|
||||
):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
|
||||
return faces_indices
|
||||
return _check_faces_indices(faces_indices, max_index, pad_value)
|
||||
|
||||
|
||||
def load_obj(
|
||||
|
@ -12,7 +12,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.io.utils import _open_file
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
|
||||
|
||||
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
|
||||
@ -221,17 +221,6 @@ class _PlyHeader:
|
||||
self.elements.append(_PlyElementType(items[1], count))
|
||||
|
||||
|
||||
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Return a 2D tensor with the specified cols and dtype filled with data,
|
||||
even when data is empty.
|
||||
"""
|
||||
if not len(data):
|
||||
return torch.zeros((0, cols), dtype=dtype)
|
||||
|
||||
return torch.tensor(data, dtype=dtype)
|
||||
|
||||
|
||||
def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType):
|
||||
"""
|
||||
Given an element which has no lists and one type, read the
|
||||
@ -691,9 +680,7 @@ def load_ply(f):
|
||||
face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
|
||||
faces = _make_tensor(face_list, cols=3, dtype=torch.int64)
|
||||
|
||||
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
return verts, faces
|
||||
|
||||
|
||||
@ -747,8 +734,7 @@ def _save_ply(
|
||||
|
||||
faces_array = faces.detach().numpy()
|
||||
|
||||
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
if len(faces_array):
|
||||
np.savetxt(f, faces_array, "3 %d %d %d")
|
||||
|
@ -2,9 +2,11 @@
|
||||
|
||||
import contextlib
|
||||
import pathlib
|
||||
from typing import IO, ContextManager
|
||||
import warnings
|
||||
from typing import IO, ContextManager, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fvcore.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
|
||||
@ -20,6 +22,34 @@ def _open_file(f, mode="r") -> ContextManager[IO]:
|
||||
return contextlib.nullcontext(f)
|
||||
|
||||
|
||||
def _make_tensor(
|
||||
data, cols: int, dtype: torch.dtype, device: str = "cpu"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Return a 2D tensor with the specified cols and dtype filled with data,
|
||||
even when data is empty.
|
||||
"""
|
||||
if not len(data):
|
||||
return torch.zeros((0, cols), dtype=dtype, device=device)
|
||||
|
||||
return torch.tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _check_faces_indices(
|
||||
faces_indices: torch.Tensor, max_index: int, pad_value: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
if pad_value is None:
|
||||
mask = torch.ones(faces_indices.shape[:-1]).bool() # Keep all faces
|
||||
else:
|
||||
# pyre-fixme[16]: `torch.ByteTensor` has no attribute `any`
|
||||
mask = faces_indices.ne(pad_value).any(dim=-1)
|
||||
if torch.any(faces_indices[mask] >= max_index) or torch.any(
|
||||
faces_indices[mask] < 0
|
||||
):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
return faces_indices
|
||||
|
||||
|
||||
def _read_image(file_name: str, format=None):
|
||||
"""
|
||||
Read an image from a file using Pillow.
|
||||
|
Loading…
x
Reference in New Issue
Block a user