From 25c065e9dafa90163e7cec873dbb324a637c68b7 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 24 Dec 2020 10:14:37 -0800 Subject: [PATCH] PathManager passing Summary: Make no internal functions inside pytorch3d/io interpret str paths except using a PathManager from iopath which they have been given. This means we no longer use any global PathManager object and we no longer use fvcore's deprecated file_io. To preserve the APIs, various top level functions create their own default-initialized PathManager object if they are not provided one. Reviewed By: theschnitz Differential Revision: D25372969 fbshipit-source-id: c176ee31439645fa54a157d6f1aef18b09501569 --- pytorch3d/io/mtl_io.py | 28 +++++++++++++++----- pytorch3d/io/obj_io.py | 59 +++++++++++++++++++++++++++++++++++------- pytorch3d/io/ply_io.py | 21 +++++++++++---- pytorch3d/io/utils.py | 11 ++++---- tests/test_obj_io.py | 5 +++- tests/test_ply_io.py | 11 +++++++- 6 files changed, 108 insertions(+), 27 deletions(-) diff --git a/pytorch3d/io/mtl_io.py b/pytorch3d/io/mtl_io.py index 2137095b..79f65d01 100644 --- a/pytorch3d/io/mtl_io.py +++ b/pytorch3d/io/mtl_io.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F +from iopath.common.file_io import PathManager from pytorch3d.io.utils import _open_file, _read_image @@ -391,12 +392,14 @@ TextureFiles = Dict[str, str] TextureImages = Dict[str, torch.Tensor] -def _parse_mtl(f, device="cpu") -> Tuple[MaterialProperties, TextureFiles]: +def _parse_mtl( + f, path_manager: PathManager, device="cpu" +) -> Tuple[MaterialProperties, TextureFiles]: material_properties = {} texture_files = {} material_name = "" - with _open_file(f, "r") as f: + with _open_file(f, path_manager, "r") as f: for line in f: tokens = line.strip().split() if not tokens: @@ -438,6 +441,7 @@ def _load_texture_images( data_dir: str, material_properties: MaterialProperties, texture_files: TextureFiles, + path_manager: PathManager, ) -> Tuple[MaterialProperties, TextureImages]: final_material_properties = {} texture_images = {} @@ -448,7 +452,9 @@ def _load_texture_images( # Load the texture image. path = os.path.join(data_dir, texture_files[material_name]) if os.path.isfile(path): - image = _read_image(path, format="RGB") / 255.0 + image = ( + _read_image(path, path_manager=path_manager, format="RGB") / 255.0 + ) image = torch.from_numpy(image) texture_images[material_name] = image else: @@ -464,7 +470,12 @@ def _load_texture_images( def load_mtl( - f, material_names: List[str], data_dir: str, device="cpu" + f, + *, + material_names: List[str], + data_dir: str, + device="cpu", + path_manager: PathManager, ) -> Tuple[MaterialProperties, TextureImages]: """ Load texture images and material reflectivity values for ambient, diffuse @@ -474,6 +485,7 @@ def load_mtl( f: a file-like object of the material information. material_names: a list of the material names found in the .obj file. data_dir: the directory where the material texture files are located. + path_manager: PathManager for interpreting both f and material_names. Returns: material_properties: dict of properties for each material. If a material @@ -494,7 +506,11 @@ def load_mtl( ... } """ - material_properties, texture_files = _parse_mtl(f, device) + material_properties, texture_files = _parse_mtl(f, path_manager, device) return _load_texture_images( - material_names, data_dir, material_properties, texture_files + material_names, + data_dir, + material_properties, + texture_files, + path_manager=path_manager, ) diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index c4733603..cac0ce25 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -9,6 +9,7 @@ from typing import List, Optional import numpy as np import torch +from iopath.common.file_io import PathManager from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.renderer import TexturesAtlas, TexturesUV @@ -68,6 +69,7 @@ def load_obj( texture_atlas_size: int = 4, texture_wrap: Optional[str] = "repeat", device="cpu", + path_manager: Optional[PathManager] = None, ): """ Load a mesh from a .obj file and optionally textures from a .mtl file. @@ -139,6 +141,7 @@ def load_obj( If `texture_mode="clamp"` the values are clamped to the range [0, 1]. If None, then there is no transformation of the texture values. device: string or torch.device on which to return the new tensors. + path_manager: optionally a PathManager object to interpret paths. Returns: 6-element tuple containing @@ -207,14 +210,17 @@ def load_obj( # pyre-fixme[6]: Expected `_PathLike[Variable[typing.AnyStr <: [str, # bytes]]]` for 1st param but got `Union[_PathLike[typing.Any], bytes, str]`. data_dir = os.path.dirname(f) - with _open_file(f, "r") as f: + if path_manager is None: + path_manager = PathManager() + with _open_file(f, path_manager, "r") as f: return _load_obj( f, - data_dir, + data_dir=data_dir, load_textures=load_textures, create_texture_atlas=create_texture_atlas, texture_atlas_size=texture_atlas_size, texture_wrap=texture_wrap, + path_manager=path_manager, device=device, ) @@ -226,6 +232,7 @@ def load_objs_as_meshes( create_texture_atlas: bool = False, texture_atlas_size: int = 4, texture_wrap: Optional[str] = "repeat", + path_manager: Optional[PathManager] = None, ): """ Load meshes from a list of .obj files using the load_obj function, and @@ -234,11 +241,13 @@ def load_objs_as_meshes( details. material_colors and normals are not stored. Args: - f: A list of file-like objects (with methods read, readline, tell, - and seek), pathlib paths or strings containing file names. + files: A list of file-like objects (with methods read, readline, tell, + and seek), pathlib paths or strings containing file names. device: Desired device of returned Meshes. Default: uses the current device for the default tensor type. load_textures: Boolean indicating whether material files are loaded + create_texture_atlas, texture_atlas_size, texture_wrap: as for load_obj. + path_manager: optionally a PathManager object to interpret paths. Returns: New Meshes object. @@ -251,6 +260,7 @@ def load_objs_as_meshes( create_texture_atlas=create_texture_atlas, texture_atlas_size=texture_atlas_size, texture_wrap=texture_wrap, + path_manager=path_manager, ) tex = None if create_texture_atlas: @@ -431,7 +441,13 @@ def _parse_obj(f, data_dir: str): def _load_materials( - material_names: List[str], f, data_dir: str, *, load_textures: bool, device + material_names: List[str], + f, + *, + data_dir: str, + load_textures: bool, + device, + path_manager: PathManager, ): """ Load materials and optionally textures from the specified path. @@ -442,6 +458,7 @@ def _load_materials( data_dir: the directory where the material texture files are located. load_textures: whether textures should be loaded. device: string or torch.device on which to return the new tensors. + path_manager: PathManager object to interpret paths. Returns: material_colors: dict of properties for each material. @@ -460,16 +477,24 @@ def _load_materials( return None, None # Texture mode uv wrap - return load_mtl(f, material_names, data_dir, device=device) + return load_mtl( + f, + material_names=material_names, + data_dir=data_dir, + path_manager=path_manager, + device=device, + ) def _load_obj( f_obj, + *, data_dir, load_textures: bool = True, create_texture_atlas: bool = False, texture_atlas_size: int = 4, texture_wrap: Optional[str] = "repeat", + path_manager: PathManager, device="cpu", ): """ @@ -522,7 +547,12 @@ def _load_obj( texture_atlas = None material_colors, texture_images = _load_materials( - material_names, mtl_path, data_dir, load_textures=load_textures, device=device + material_names, + mtl_path, + data_dir=data_dir, + load_textures=load_textures, + path_manager=path_manager, + device=device, ) if create_texture_atlas: @@ -562,7 +592,13 @@ def _load_obj( return verts, faces, aux -def save_obj(f, verts, faces, decimal_places: Optional[int] = None): +def save_obj( + f, + verts, + faces, + decimal_places: Optional[int] = None, + path_manager: Optional[PathManager] = None, +): """ Save a mesh to an .obj file. @@ -571,6 +607,8 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. + path_manager: Optional PathManager for interpreting f if + it is a str. """ if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." @@ -580,7 +618,10 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." raise ValueError(message) - with _open_file(f, "w") as f: + if path_manager is None: + path_manager = PathManager() + + with _open_file(f, path_manager, "w") as f: return _save(f, verts, faces, decimal_places) diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index a9e9942e..5a1b2baa 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -13,6 +13,7 @@ from typing import Optional, Tuple import numpy as np import torch +from iopath.common.file_io import PathManager from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file @@ -585,7 +586,7 @@ def _load_ply_raw_stream(f) -> Tuple[_PlyHeader, dict]: return header, elements -def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]: +def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]: """ Load the data from a .ply file. @@ -594,6 +595,7 @@ def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]: tell and seek), a pathlib path or a string containing a file name. If the ply file is binary, a text stream is not supported. It is recommended to use a binary stream. + path_manager: PathManager for loading if f is a str. Returns: header: A _PlyHeader object describing the metadata in the ply file. @@ -602,12 +604,12 @@ def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]: uniformly-sized list, then the value will be a 2D numpy array. If not, it is a list of the relevant property values. """ - with _open_file(f, "rb") as f: + with _open_file(f, path_manager, "rb") as f: header, elements = _load_ply_raw_stream(f) return header, elements -def load_ply(f): +def load_ply(f, path_manager: Optional[PathManager] = None): """ Load the data from a .ply file. @@ -645,12 +647,16 @@ def load_ply(f): If the ply file is in the binary ply format rather than the text ply format, then a text stream is not supported. It is easiest to use a binary stream in all cases. + path_manager: PathManager for loading if f is a str. + Returns: verts: FloatTensor of shape (V, 3). faces: LongTensor of vertex indices, shape (F, 3). """ - header, elements = _load_ply_raw(f) + if path_manager is None: + path_manager = PathManager() + header, elements = _load_ply_raw(f, path_manager=path_manager) vertex = elements.get("vertex", None) if vertex is None: @@ -780,6 +786,7 @@ def save_ply( verts_normals: Optional[torch.Tensor] = None, ascii: bool = False, decimal_places: Optional[int] = None, + path_manager: Optional[PathManager] = None, ) -> None: """ Save a mesh to a .ply file. @@ -791,6 +798,8 @@ def save_ply( verts_normals: FloatTensor of shape (V, 3) giving vertex normals. ascii: (bool) whether to use the ascii ply format. decimal_places: Number of decimal places for saving if ascii=True. + path_manager: PathManager for interpreting f if it is a str. + """ verts_normals = ( @@ -816,5 +825,7 @@ def save_ply( message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." raise ValueError(message) - with _open_file(f, "wb") as f: + if path_manager is None: + path_manager = PathManager() + with _open_file(f, path_manager, "wb") as f: _save_ply(f, verts, faces, verts_normals, ascii, decimal_places) diff --git a/pytorch3d/io/utils.py b/pytorch3d/io/utils.py index 7cea7947..a9191789 100644 --- a/pytorch3d/io/utils.py +++ b/pytorch3d/io/utils.py @@ -7,7 +7,7 @@ from typing import IO, ContextManager, Optional import numpy as np import torch -from fvcore.common.file_io import PathManager +from iopath.common.file_io import PathManager from PIL import Image @@ -19,9 +19,9 @@ def nullcontext(x): yield x -def _open_file(f, mode="r") -> ContextManager[IO]: +def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]: if isinstance(f, str): - f = open(f, mode) + f = path_manager.open(f, mode) return contextlib.closing(f) elif isinstance(f, pathlib.Path): f = f.open(mode) @@ -58,18 +58,19 @@ def _check_faces_indices( return faces_indices -def _read_image(file_name: str, format=None): +def _read_image(file_name: str, path_manager: PathManager, format=None): """ Read an image from a file using Pillow. Args: file_name: image file path. + path_manager: PathManager for interpreting file_name. format: one of ["RGB", "BGR"] Returns: image: an image of shape (H, W, C). """ if format not in ["RGB", "BGR"]: raise ValueError("format can only be one of [RGB, BGR]; got %s", format) - with PathManager.open(file_name, "rb") as f: + with path_manager.open(file_name, "rb") as f: # pyre-fixme[6]: Expected `Union[str, typing.BinaryIO]` for 1st param but # got `Union[typing.IO[bytes], typing.IO[str]]`. image = Image.open(f) diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index a0cd644f..cccc87fa 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -8,6 +8,7 @@ from pathlib import Path import torch from common_testing import TestCaseMixin +from iopath.common.file_io import PathManager from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj from pytorch3d.io.mtl_io import ( _bilinear_interpolation_grid_sample, @@ -460,7 +461,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ] ) mtl_file = StringIO(mtl_file) - material_properties, texture_files = _parse_mtl(mtl_file, device="cpu") + material_properties, texture_files = _parse_mtl( + mtl_file, path_manager=PathManager(), device="cpu" + ) dtype = torch.float32 expected_materials = { diff --git a/tests/test_ply_io.py b/tests/test_ply_io.py index 408311bf..48d2dfaf 100644 --- a/tests/test_ply_io.py +++ b/tests/test_ply_io.py @@ -5,12 +5,21 @@ import unittest from io import BytesIO, StringIO from tempfile import TemporaryFile +import pytorch3d.io.ply_io import torch from common_testing import TestCaseMixin -from pytorch3d.io.ply_io import _load_ply_raw, load_ply, save_ply +from iopath.common.file_io import PathManager +from pytorch3d.io.ply_io import load_ply, save_ply from pytorch3d.utils import torus +global_path_manager = PathManager() + + +def _load_ply_raw(stream): + return pytorch3d.io.ply_io._load_ply_raw(stream, global_path_manager) + + class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): def test_raw_load_simple_ascii(self): ply_file = "\n".join(