PathManager passing

Summary:
Make no internal functions inside pytorch3d/io interpret str paths except using a PathManager from iopath which they have been given. This means we no longer use any global PathManager object and we no longer use fvcore's deprecated file_io.

To preserve the APIs, various top level functions create their own default-initialized PathManager object if they are not provided one.

Reviewed By: theschnitz

Differential Revision: D25372969

fbshipit-source-id: c176ee31439645fa54a157d6f1aef18b09501569
This commit is contained in:
Jeremy Reizenstein 2020-12-24 10:14:37 -08:00 committed by Facebook GitHub Bot
parent b95621573b
commit 25c065e9da
6 changed files with 108 additions and 27 deletions

View File

@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _open_file, _read_image from pytorch3d.io.utils import _open_file, _read_image
@ -391,12 +392,14 @@ TextureFiles = Dict[str, str]
TextureImages = Dict[str, torch.Tensor] TextureImages = Dict[str, torch.Tensor]
def _parse_mtl(f, device="cpu") -> Tuple[MaterialProperties, TextureFiles]: def _parse_mtl(
f, path_manager: PathManager, device="cpu"
) -> Tuple[MaterialProperties, TextureFiles]:
material_properties = {} material_properties = {}
texture_files = {} texture_files = {}
material_name = "" material_name = ""
with _open_file(f, "r") as f: with _open_file(f, path_manager, "r") as f:
for line in f: for line in f:
tokens = line.strip().split() tokens = line.strip().split()
if not tokens: if not tokens:
@ -438,6 +441,7 @@ def _load_texture_images(
data_dir: str, data_dir: str,
material_properties: MaterialProperties, material_properties: MaterialProperties,
texture_files: TextureFiles, texture_files: TextureFiles,
path_manager: PathManager,
) -> Tuple[MaterialProperties, TextureImages]: ) -> Tuple[MaterialProperties, TextureImages]:
final_material_properties = {} final_material_properties = {}
texture_images = {} texture_images = {}
@ -448,7 +452,9 @@ def _load_texture_images(
# Load the texture image. # Load the texture image.
path = os.path.join(data_dir, texture_files[material_name]) path = os.path.join(data_dir, texture_files[material_name])
if os.path.isfile(path): if os.path.isfile(path):
image = _read_image(path, format="RGB") / 255.0 image = (
_read_image(path, path_manager=path_manager, format="RGB") / 255.0
)
image = torch.from_numpy(image) image = torch.from_numpy(image)
texture_images[material_name] = image texture_images[material_name] = image
else: else:
@ -464,7 +470,12 @@ def _load_texture_images(
def load_mtl( def load_mtl(
f, material_names: List[str], data_dir: str, device="cpu" f,
*,
material_names: List[str],
data_dir: str,
device="cpu",
path_manager: PathManager,
) -> Tuple[MaterialProperties, TextureImages]: ) -> Tuple[MaterialProperties, TextureImages]:
""" """
Load texture images and material reflectivity values for ambient, diffuse Load texture images and material reflectivity values for ambient, diffuse
@ -474,6 +485,7 @@ def load_mtl(
f: a file-like object of the material information. f: a file-like object of the material information.
material_names: a list of the material names found in the .obj file. material_names: a list of the material names found in the .obj file.
data_dir: the directory where the material texture files are located. data_dir: the directory where the material texture files are located.
path_manager: PathManager for interpreting both f and material_names.
Returns: Returns:
material_properties: dict of properties for each material. If a material material_properties: dict of properties for each material. If a material
@ -494,7 +506,11 @@ def load_mtl(
... ...
} }
""" """
material_properties, texture_files = _parse_mtl(f, device) material_properties, texture_files = _parse_mtl(f, path_manager, device)
return _load_texture_images( return _load_texture_images(
material_names, data_dir, material_properties, texture_files material_names,
data_dir,
material_properties,
texture_files,
path_manager=path_manager,
) )

View File

@ -9,6 +9,7 @@ from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesUV from pytorch3d.renderer import TexturesAtlas, TexturesUV
@ -68,6 +69,7 @@ def load_obj(
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat", texture_wrap: Optional[str] = "repeat",
device="cpu", device="cpu",
path_manager: Optional[PathManager] = None,
): ):
""" """
Load a mesh from a .obj file and optionally textures from a .mtl file. Load a mesh from a .obj file and optionally textures from a .mtl file.
@ -139,6 +141,7 @@ def load_obj(
If `texture_mode="clamp"` the values are clamped to the range [0, 1]. If `texture_mode="clamp"` the values are clamped to the range [0, 1].
If None, then there is no transformation of the texture values. If None, then there is no transformation of the texture values.
device: string or torch.device on which to return the new tensors. device: string or torch.device on which to return the new tensors.
path_manager: optionally a PathManager object to interpret paths.
Returns: Returns:
6-element tuple containing 6-element tuple containing
@ -207,14 +210,17 @@ def load_obj(
# pyre-fixme[6]: Expected `_PathLike[Variable[typing.AnyStr <: [str, # pyre-fixme[6]: Expected `_PathLike[Variable[typing.AnyStr <: [str,
# bytes]]]` for 1st param but got `Union[_PathLike[typing.Any], bytes, str]`. # bytes]]]` for 1st param but got `Union[_PathLike[typing.Any], bytes, str]`.
data_dir = os.path.dirname(f) data_dir = os.path.dirname(f)
with _open_file(f, "r") as f: if path_manager is None:
path_manager = PathManager()
with _open_file(f, path_manager, "r") as f:
return _load_obj( return _load_obj(
f, f,
data_dir, data_dir=data_dir,
load_textures=load_textures, load_textures=load_textures,
create_texture_atlas=create_texture_atlas, create_texture_atlas=create_texture_atlas,
texture_atlas_size=texture_atlas_size, texture_atlas_size=texture_atlas_size,
texture_wrap=texture_wrap, texture_wrap=texture_wrap,
path_manager=path_manager,
device=device, device=device,
) )
@ -226,6 +232,7 @@ def load_objs_as_meshes(
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat", texture_wrap: Optional[str] = "repeat",
path_manager: Optional[PathManager] = None,
): ):
""" """
Load meshes from a list of .obj files using the load_obj function, and Load meshes from a list of .obj files using the load_obj function, and
@ -234,11 +241,13 @@ def load_objs_as_meshes(
details. material_colors and normals are not stored. details. material_colors and normals are not stored.
Args: Args:
f: A list of file-like objects (with methods read, readline, tell, files: A list of file-like objects (with methods read, readline, tell,
and seek), pathlib paths or strings containing file names. and seek), pathlib paths or strings containing file names.
device: Desired device of returned Meshes. Default: device: Desired device of returned Meshes. Default:
uses the current device for the default tensor type. uses the current device for the default tensor type.
load_textures: Boolean indicating whether material files are loaded load_textures: Boolean indicating whether material files are loaded
create_texture_atlas, texture_atlas_size, texture_wrap: as for load_obj.
path_manager: optionally a PathManager object to interpret paths.
Returns: Returns:
New Meshes object. New Meshes object.
@ -251,6 +260,7 @@ def load_objs_as_meshes(
create_texture_atlas=create_texture_atlas, create_texture_atlas=create_texture_atlas,
texture_atlas_size=texture_atlas_size, texture_atlas_size=texture_atlas_size,
texture_wrap=texture_wrap, texture_wrap=texture_wrap,
path_manager=path_manager,
) )
tex = None tex = None
if create_texture_atlas: if create_texture_atlas:
@ -431,7 +441,13 @@ def _parse_obj(f, data_dir: str):
def _load_materials( def _load_materials(
material_names: List[str], f, data_dir: str, *, load_textures: bool, device material_names: List[str],
f,
*,
data_dir: str,
load_textures: bool,
device,
path_manager: PathManager,
): ):
""" """
Load materials and optionally textures from the specified path. Load materials and optionally textures from the specified path.
@ -442,6 +458,7 @@ def _load_materials(
data_dir: the directory where the material texture files are located. data_dir: the directory where the material texture files are located.
load_textures: whether textures should be loaded. load_textures: whether textures should be loaded.
device: string or torch.device on which to return the new tensors. device: string or torch.device on which to return the new tensors.
path_manager: PathManager object to interpret paths.
Returns: Returns:
material_colors: dict of properties for each material. material_colors: dict of properties for each material.
@ -460,16 +477,24 @@ def _load_materials(
return None, None return None, None
# Texture mode uv wrap # Texture mode uv wrap
return load_mtl(f, material_names, data_dir, device=device) return load_mtl(
f,
material_names=material_names,
data_dir=data_dir,
path_manager=path_manager,
device=device,
)
def _load_obj( def _load_obj(
f_obj, f_obj,
*,
data_dir, data_dir,
load_textures: bool = True, load_textures: bool = True,
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat", texture_wrap: Optional[str] = "repeat",
path_manager: PathManager,
device="cpu", device="cpu",
): ):
""" """
@ -522,7 +547,12 @@ def _load_obj(
texture_atlas = None texture_atlas = None
material_colors, texture_images = _load_materials( material_colors, texture_images = _load_materials(
material_names, mtl_path, data_dir, load_textures=load_textures, device=device material_names,
mtl_path,
data_dir=data_dir,
load_textures=load_textures,
path_manager=path_manager,
device=device,
) )
if create_texture_atlas: if create_texture_atlas:
@ -562,7 +592,13 @@ def _load_obj(
return verts, faces, aux return verts, faces, aux
def save_obj(f, verts, faces, decimal_places: Optional[int] = None): def save_obj(
f,
verts,
faces,
decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
):
""" """
Save a mesh to an .obj file. Save a mesh to an .obj file.
@ -571,6 +607,8 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
verts: FloatTensor of shape (V, 3) giving vertex coordinates. verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving. decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if
it is a str.
""" """
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
@ -580,7 +618,10 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message) raise ValueError(message)
with _open_file(f, "w") as f: if path_manager is None:
path_manager = PathManager()
with _open_file(f, path_manager, "w") as f:
return _save(f, verts, faces, decimal_places) return _save(f, verts, faces, decimal_places)

View File

@ -13,6 +13,7 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
@ -585,7 +586,7 @@ def _load_ply_raw_stream(f) -> Tuple[_PlyHeader, dict]:
return header, elements return header, elements
def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]: def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
""" """
Load the data from a .ply file. Load the data from a .ply file.
@ -594,6 +595,7 @@ def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]:
tell and seek), a pathlib path or a string containing a file name. tell and seek), a pathlib path or a string containing a file name.
If the ply file is binary, a text stream is not supported. If the ply file is binary, a text stream is not supported.
It is recommended to use a binary stream. It is recommended to use a binary stream.
path_manager: PathManager for loading if f is a str.
Returns: Returns:
header: A _PlyHeader object describing the metadata in the ply file. header: A _PlyHeader object describing the metadata in the ply file.
@ -602,12 +604,12 @@ def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]:
uniformly-sized list, then the value will be a 2D numpy array. uniformly-sized list, then the value will be a 2D numpy array.
If not, it is a list of the relevant property values. If not, it is a list of the relevant property values.
""" """
with _open_file(f, "rb") as f: with _open_file(f, path_manager, "rb") as f:
header, elements = _load_ply_raw_stream(f) header, elements = _load_ply_raw_stream(f)
return header, elements return header, elements
def load_ply(f): def load_ply(f, path_manager: Optional[PathManager] = None):
""" """
Load the data from a .ply file. Load the data from a .ply file.
@ -645,12 +647,16 @@ def load_ply(f):
If the ply file is in the binary ply format rather than the text If the ply file is in the binary ply format rather than the text
ply format, then a text stream is not supported. ply format, then a text stream is not supported.
It is easiest to use a binary stream in all cases. It is easiest to use a binary stream in all cases.
path_manager: PathManager for loading if f is a str.
Returns: Returns:
verts: FloatTensor of shape (V, 3). verts: FloatTensor of shape (V, 3).
faces: LongTensor of vertex indices, shape (F, 3). faces: LongTensor of vertex indices, shape (F, 3).
""" """
header, elements = _load_ply_raw(f) if path_manager is None:
path_manager = PathManager()
header, elements = _load_ply_raw(f, path_manager=path_manager)
vertex = elements.get("vertex", None) vertex = elements.get("vertex", None)
if vertex is None: if vertex is None:
@ -780,6 +786,7 @@ def save_ply(
verts_normals: Optional[torch.Tensor] = None, verts_normals: Optional[torch.Tensor] = None,
ascii: bool = False, ascii: bool = False,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
) -> None: ) -> None:
""" """
Save a mesh to a .ply file. Save a mesh to a .ply file.
@ -791,6 +798,8 @@ def save_ply(
verts_normals: FloatTensor of shape (V, 3) giving vertex normals. verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
ascii: (bool) whether to use the ascii ply format. ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True. decimal_places: Number of decimal places for saving if ascii=True.
path_manager: PathManager for interpreting f if it is a str.
""" """
verts_normals = ( verts_normals = (
@ -816,5 +825,7 @@ def save_ply(
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
with _open_file(f, "wb") as f: if path_manager is None:
path_manager = PathManager()
with _open_file(f, path_manager, "wb") as f:
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places) _save_ply(f, verts, faces, verts_normals, ascii, decimal_places)

View File

@ -7,7 +7,7 @@ from typing import IO, ContextManager, Optional
import numpy as np import numpy as np
import torch import torch
from fvcore.common.file_io import PathManager from iopath.common.file_io import PathManager
from PIL import Image from PIL import Image
@ -19,9 +19,9 @@ def nullcontext(x):
yield x yield x
def _open_file(f, mode="r") -> ContextManager[IO]: def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
if isinstance(f, str): if isinstance(f, str):
f = open(f, mode) f = path_manager.open(f, mode)
return contextlib.closing(f) return contextlib.closing(f)
elif isinstance(f, pathlib.Path): elif isinstance(f, pathlib.Path):
f = f.open(mode) f = f.open(mode)
@ -58,18 +58,19 @@ def _check_faces_indices(
return faces_indices return faces_indices
def _read_image(file_name: str, format=None): def _read_image(file_name: str, path_manager: PathManager, format=None):
""" """
Read an image from a file using Pillow. Read an image from a file using Pillow.
Args: Args:
file_name: image file path. file_name: image file path.
path_manager: PathManager for interpreting file_name.
format: one of ["RGB", "BGR"] format: one of ["RGB", "BGR"]
Returns: Returns:
image: an image of shape (H, W, C). image: an image of shape (H, W, C).
""" """
if format not in ["RGB", "BGR"]: if format not in ["RGB", "BGR"]:
raise ValueError("format can only be one of [RGB, BGR]; got %s", format) raise ValueError("format can only be one of [RGB, BGR]; got %s", format)
with PathManager.open(file_name, "rb") as f: with path_manager.open(file_name, "rb") as f:
# pyre-fixme[6]: Expected `Union[str, typing.BinaryIO]` for 1st param but # pyre-fixme[6]: Expected `Union[str, typing.BinaryIO]` for 1st param but
# got `Union[typing.IO[bytes], typing.IO[str]]`. # got `Union[typing.IO[bytes], typing.IO[str]]`.
image = Image.open(f) image = Image.open(f)

View File

@ -8,6 +8,7 @@ from pathlib import Path
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import ( from pytorch3d.io.mtl_io import (
_bilinear_interpolation_grid_sample, _bilinear_interpolation_grid_sample,
@ -460,7 +461,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
] ]
) )
mtl_file = StringIO(mtl_file) mtl_file = StringIO(mtl_file)
material_properties, texture_files = _parse_mtl(mtl_file, device="cpu") material_properties, texture_files = _parse_mtl(
mtl_file, path_manager=PathManager(), device="cpu"
)
dtype = torch.float32 dtype = torch.float32
expected_materials = { expected_materials = {

View File

@ -5,12 +5,21 @@ import unittest
from io import BytesIO, StringIO from io import BytesIO, StringIO
from tempfile import TemporaryFile from tempfile import TemporaryFile
import pytorch3d.io.ply_io
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.io.ply_io import _load_ply_raw, load_ply, save_ply from iopath.common.file_io import PathManager
from pytorch3d.io.ply_io import load_ply, save_ply
from pytorch3d.utils import torus from pytorch3d.utils import torus
global_path_manager = PathManager()
def _load_ply_raw(stream):
return pytorch3d.io.ply_io._load_ply_raw(stream, global_path_manager)
class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
def test_raw_load_simple_ascii(self): def test_raw_load_simple_ascii(self):
ply_file = "\n".join( ply_file = "\n".join(