From 0505e5f4a9c52f4ce818884ed18d488b68524343 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Mon, 13 Jul 2020 12:03:01 -0700 Subject: [PATCH] Make _open_file() return a context manager Summary: Make the `_open_file()` function return a context manager and remove the associated file closure Reviewed By: nikhilaravi Differential Revision: D20720506 fbshipit-source-id: 7d96ceb2fd64b6ee3985d0b0faf8d8bf791b1966 --- pytorch3d/io/mtl_io.py | 67 ++++++++++++++++++++---------------------- pytorch3d/io/obj_io.py | 24 +++++---------- pytorch3d/io/ply_io.py | 13 ++------ pytorch3d/io/utils.py | 13 ++++---- 4 files changed, 49 insertions(+), 68 deletions(-) diff --git a/pytorch3d/io/mtl_io.py b/pytorch3d/io/mtl_io.py index 97f5ba74..885a636f 100644 --- a/pytorch3d/io/mtl_io.py +++ b/pytorch3d/io/mtl_io.py @@ -379,13 +379,13 @@ def _bilinear_interpolation_grid_sample( return out.permute(0, 2, 3, 1) -def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"): +def load_mtl(f, material_names: List, data_dir: str, device="cpu"): """ Load texture images and material reflectivity values for ambient, diffuse and specular light (Ka, Kd, Ks, Ns). Args: - f_mtl: a file like object of the material information. + f: a file-like object of the material information. material_names: a list of the material names found in the .obj file. data_dir: the directory where the material texture files are located. @@ -414,39 +414,36 @@ def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"): texture_images = {} material_name = "" - f_mtl, new_f = _open_file(f_mtl) - lines = [line.strip() for line in f_mtl] - for line in lines: - if len(line.split()) != 0: - if line.split()[0] == "newmtl": - material_name = line.split()[1] - material_colors[material_name] = {} - if line.split()[0] == "map_Kd": - # Texture map. - texture_files[material_name] = line.split()[1] - if line.split()[0] == "Kd": - # RGB diffuse reflectivity - kd = np.array(list(line.split()[1:4])).astype(np.float32) - kd = torch.from_numpy(kd).to(device) - material_colors[material_name]["diffuse_color"] = kd - if line.split()[0] == "Ka": - # RGB ambient reflectivity - ka = np.array(list(line.split()[1:4])).astype(np.float32) - ka = torch.from_numpy(ka).to(device) - material_colors[material_name]["ambient_color"] = ka - if line.split()[0] == "Ks": - # RGB specular reflectivity - ks = np.array(list(line.split()[1:4])).astype(np.float32) - ks = torch.from_numpy(ks).to(device) - material_colors[material_name]["specular_color"] = ks - if line.split()[0] == "Ns": - # Specular exponent - ns = np.array(list(line.split()[1:4])).astype(np.float32) - ns = torch.from_numpy(ns).to(device) - material_colors[material_name]["shininess"] = ns - - if new_f: - f_mtl.close() + with _open_file(f) as f: + lines = [line.strip() for line in f] + for line in lines: + if len(line.split()) != 0: + if line.split()[0] == "newmtl": + material_name = line.split()[1] + material_colors[material_name] = {} + if line.split()[0] == "map_Kd": + # Texture map. + texture_files[material_name] = line.split()[1] + if line.split()[0] == "Kd": + # RGB diffuse reflectivity + kd = np.array(list(line.split()[1:4])).astype(np.float32) + kd = torch.from_numpy(kd).to(device) + material_colors[material_name]["diffuse_color"] = kd + if line.split()[0] == "Ka": + # RGB ambient reflectivity + ka = np.array(list(line.split()[1:4])).astype(np.float32) + ka = torch.from_numpy(ka).to(device) + material_colors[material_name]["ambient_color"] = ka + if line.split()[0] == "Ks": + # RGB specular reflectivity + ks = np.array(list(line.split()[1:4])).astype(np.float32) + ks = torch.from_numpy(ks).to(device) + material_colors[material_name]["specular_color"] = ks + if line.split()[0] == "Ns": + # Specular exponent + ns = np.array(list(line.split()[1:4])).astype(np.float32) + ns = torch.from_numpy(ns).to(device) + material_colors[material_name]["shininess"] = ns # Only keep the materials referenced in the obj. for name in material_names: diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index 2affff9b..acfebeb5 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -72,7 +72,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): def load_obj( - f_obj, + f, load_textures=True, create_texture_atlas: bool = False, texture_atlas_size: int = 4, @@ -211,14 +211,13 @@ def load_obj( None. """ data_dir = "./" - if isinstance(f_obj, (str, bytes, os.PathLike)): + if isinstance(f, (str, bytes, os.PathLike)): # pyre-fixme[6]: Expected `_PathLike[Variable[typing.AnyStr <: [str, # bytes]]]` for 1st param but got `Union[_PathLike[typing.Any], bytes, str]`. - data_dir = os.path.dirname(f_obj) - f_obj, new_f = _open_file(f_obj, "r") - try: - return _load( - f_obj, + data_dir = os.path.dirname(f) + with _open_file(f, "r") as f: + return _load_obj( + f, data_dir, load_textures=load_textures, create_texture_atlas=create_texture_atlas, @@ -226,9 +225,6 @@ def load_obj( texture_wrap=texture_wrap, device=device, ) - finally: - if new_f: - f_obj.close() def load_objs_as_meshes(files: list, device=None, load_textures: bool = True): @@ -338,7 +334,7 @@ def _parse_face( faces_materials_idx.append(material_idx) -def _load( +def _load_obj( f_obj, data_dir, load_textures: bool = True, @@ -523,12 +519,8 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." raise ValueError(message) - f, new_f = _open_file(f, "w") - try: + with _open_file(f, "w") as f: return _save(f, verts, faces, decimal_places) - finally: - if new_f: - f.close() # TODO (nikhilar) Speed up this function. diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 4afc46ee..76709436 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -603,13 +603,8 @@ def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]: uniformly-sized list, then the value will be a 2D numpy array. If not, it is a list of the relevant property values. """ - f, new_f = _open_file(f, "rb") - try: + with _open_file(f, "rb") as f: header, elements = _load_ply_raw_stream(f) - finally: - if new_f: - f.close() - return header, elements @@ -800,9 +795,5 @@ def save_ply( message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." raise ValueError(message) - f, new_f = _open_file(f, "w") - try: + with _open_file(f, "w") as f: _save_ply(f, verts, faces, verts_normals, decimal_places) - finally: - if new_f: - f.close() diff --git a/pytorch3d/io/utils.py b/pytorch3d/io/utils.py index 88bbcc65..278ced15 100644 --- a/pytorch3d/io/utils.py +++ b/pytorch3d/io/utils.py @@ -1,22 +1,23 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import contextlib import pathlib +from typing import IO, ContextManager import numpy as np from fvcore.common.file_io import PathManager from PIL import Image -# TODO(plabatut): Replace with a context manager -def _open_file(f, mode="r"): - new_f = False +def _open_file(f, mode="r") -> ContextManager[IO]: if isinstance(f, str): - new_f = True f = open(f, mode) + return contextlib.closing(f) elif isinstance(f, pathlib.Path): - new_f = True f = f.open(mode) - return f, new_f + return contextlib.closing(f) + else: + return contextlib.nullcontext(f) def _read_image(file_name: str, format=None):