Make _open_file() return a context manager

Summary: Make the `_open_file()` function return a context manager and remove the associated file closure

Reviewed By: nikhilaravi

Differential Revision: D20720506

fbshipit-source-id: 7d96ceb2fd64b6ee3985d0b0faf8d8bf791b1966
This commit is contained in:
Patrick Labatut 2020-07-13 12:03:01 -07:00 committed by Facebook GitHub Bot
parent e2b47f047e
commit 0505e5f4a9
4 changed files with 49 additions and 68 deletions

View File

@ -379,13 +379,13 @@ def _bilinear_interpolation_grid_sample(
return out.permute(0, 2, 3, 1) return out.permute(0, 2, 3, 1)
def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"): def load_mtl(f, material_names: List, data_dir: str, device="cpu"):
""" """
Load texture images and material reflectivity values for ambient, diffuse Load texture images and material reflectivity values for ambient, diffuse
and specular light (Ka, Kd, Ks, Ns). and specular light (Ka, Kd, Ks, Ns).
Args: Args:
f_mtl: a file like object of the material information. f: a file-like object of the material information.
material_names: a list of the material names found in the .obj file. material_names: a list of the material names found in the .obj file.
data_dir: the directory where the material texture files are located. data_dir: the directory where the material texture files are located.
@ -414,8 +414,8 @@ def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"):
texture_images = {} texture_images = {}
material_name = "" material_name = ""
f_mtl, new_f = _open_file(f_mtl) with _open_file(f) as f:
lines = [line.strip() for line in f_mtl] lines = [line.strip() for line in f]
for line in lines: for line in lines:
if len(line.split()) != 0: if len(line.split()) != 0:
if line.split()[0] == "newmtl": if line.split()[0] == "newmtl":
@ -445,9 +445,6 @@ def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"):
ns = torch.from_numpy(ns).to(device) ns = torch.from_numpy(ns).to(device)
material_colors[material_name]["shininess"] = ns material_colors[material_name]["shininess"] = ns
if new_f:
f_mtl.close()
# Only keep the materials referenced in the obj. # Only keep the materials referenced in the obj.
for name in material_names: for name in material_names:
if name in texture_files: if name in texture_files:

View File

@ -72,7 +72,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
def load_obj( def load_obj(
f_obj, f,
load_textures=True, load_textures=True,
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
@ -211,14 +211,13 @@ def load_obj(
None. None.
""" """
data_dir = "./" data_dir = "./"
if isinstance(f_obj, (str, bytes, os.PathLike)): if isinstance(f, (str, bytes, os.PathLike)):
# pyre-fixme[6]: Expected `_PathLike[Variable[typing.AnyStr <: [str, # pyre-fixme[6]: Expected `_PathLike[Variable[typing.AnyStr <: [str,
# bytes]]]` for 1st param but got `Union[_PathLike[typing.Any], bytes, str]`. # bytes]]]` for 1st param but got `Union[_PathLike[typing.Any], bytes, str]`.
data_dir = os.path.dirname(f_obj) data_dir = os.path.dirname(f)
f_obj, new_f = _open_file(f_obj, "r") with _open_file(f, "r") as f:
try: return _load_obj(
return _load( f,
f_obj,
data_dir, data_dir,
load_textures=load_textures, load_textures=load_textures,
create_texture_atlas=create_texture_atlas, create_texture_atlas=create_texture_atlas,
@ -226,9 +225,6 @@ def load_obj(
texture_wrap=texture_wrap, texture_wrap=texture_wrap,
device=device, device=device,
) )
finally:
if new_f:
f_obj.close()
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True): def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
@ -338,7 +334,7 @@ def _parse_face(
faces_materials_idx.append(material_idx) faces_materials_idx.append(material_idx)
def _load( def _load_obj(
f_obj, f_obj,
data_dir, data_dir,
load_textures: bool = True, load_textures: bool = True,
@ -523,12 +519,8 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message) raise ValueError(message)
f, new_f = _open_file(f, "w") with _open_file(f, "w") as f:
try:
return _save(f, verts, faces, decimal_places) return _save(f, verts, faces, decimal_places)
finally:
if new_f:
f.close()
# TODO (nikhilar) Speed up this function. # TODO (nikhilar) Speed up this function.

View File

@ -603,13 +603,8 @@ def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]:
uniformly-sized list, then the value will be a 2D numpy array. uniformly-sized list, then the value will be a 2D numpy array.
If not, it is a list of the relevant property values. If not, it is a list of the relevant property values.
""" """
f, new_f = _open_file(f, "rb") with _open_file(f, "rb") as f:
try:
header, elements = _load_ply_raw_stream(f) header, elements = _load_ply_raw_stream(f)
finally:
if new_f:
f.close()
return header, elements return header, elements
@ -800,9 +795,5 @@ def save_ply(
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
f, new_f = _open_file(f, "w") with _open_file(f, "w") as f:
try:
_save_ply(f, verts, faces, verts_normals, decimal_places) _save_ply(f, verts, faces, verts_normals, decimal_places)
finally:
if new_f:
f.close()

View File

@ -1,22 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import contextlib
import pathlib import pathlib
from typing import IO, ContextManager
import numpy as np import numpy as np
from fvcore.common.file_io import PathManager from fvcore.common.file_io import PathManager
from PIL import Image from PIL import Image
# TODO(plabatut): Replace with a context manager def _open_file(f, mode="r") -> ContextManager[IO]:
def _open_file(f, mode="r"):
new_f = False
if isinstance(f, str): if isinstance(f, str):
new_f = True
f = open(f, mode) f = open(f, mode)
return contextlib.closing(f)
elif isinstance(f, pathlib.Path): elif isinstance(f, pathlib.Path):
new_f = True
f = f.open(mode) f = f.open(mode)
return f, new_f return contextlib.closing(f)
else:
return contextlib.nullcontext(f)
def _read_image(file_name: str, format=None): def _read_image(file_name: str, format=None):