mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Make _open_file() return a context manager
Summary: Make the `_open_file()` function return a context manager and remove the associated file closure Reviewed By: nikhilaravi Differential Revision: D20720506 fbshipit-source-id: 7d96ceb2fd64b6ee3985d0b0faf8d8bf791b1966
This commit is contained in:
parent
e2b47f047e
commit
0505e5f4a9
@ -379,13 +379,13 @@ def _bilinear_interpolation_grid_sample(
|
||||
return out.permute(0, 2, 3, 1)
|
||||
|
||||
|
||||
def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"):
|
||||
def load_mtl(f, material_names: List, data_dir: str, device="cpu"):
|
||||
"""
|
||||
Load texture images and material reflectivity values for ambient, diffuse
|
||||
and specular light (Ka, Kd, Ks, Ns).
|
||||
|
||||
Args:
|
||||
f_mtl: a file like object of the material information.
|
||||
f: a file-like object of the material information.
|
||||
material_names: a list of the material names found in the .obj file.
|
||||
data_dir: the directory where the material texture files are located.
|
||||
|
||||
@ -414,39 +414,36 @@ def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"):
|
||||
texture_images = {}
|
||||
material_name = ""
|
||||
|
||||
f_mtl, new_f = _open_file(f_mtl)
|
||||
lines = [line.strip() for line in f_mtl]
|
||||
for line in lines:
|
||||
if len(line.split()) != 0:
|
||||
if line.split()[0] == "newmtl":
|
||||
material_name = line.split()[1]
|
||||
material_colors[material_name] = {}
|
||||
if line.split()[0] == "map_Kd":
|
||||
# Texture map.
|
||||
texture_files[material_name] = line.split()[1]
|
||||
if line.split()[0] == "Kd":
|
||||
# RGB diffuse reflectivity
|
||||
kd = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
kd = torch.from_numpy(kd).to(device)
|
||||
material_colors[material_name]["diffuse_color"] = kd
|
||||
if line.split()[0] == "Ka":
|
||||
# RGB ambient reflectivity
|
||||
ka = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ka = torch.from_numpy(ka).to(device)
|
||||
material_colors[material_name]["ambient_color"] = ka
|
||||
if line.split()[0] == "Ks":
|
||||
# RGB specular reflectivity
|
||||
ks = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ks = torch.from_numpy(ks).to(device)
|
||||
material_colors[material_name]["specular_color"] = ks
|
||||
if line.split()[0] == "Ns":
|
||||
# Specular exponent
|
||||
ns = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ns = torch.from_numpy(ns).to(device)
|
||||
material_colors[material_name]["shininess"] = ns
|
||||
|
||||
if new_f:
|
||||
f_mtl.close()
|
||||
with _open_file(f) as f:
|
||||
lines = [line.strip() for line in f]
|
||||
for line in lines:
|
||||
if len(line.split()) != 0:
|
||||
if line.split()[0] == "newmtl":
|
||||
material_name = line.split()[1]
|
||||
material_colors[material_name] = {}
|
||||
if line.split()[0] == "map_Kd":
|
||||
# Texture map.
|
||||
texture_files[material_name] = line.split()[1]
|
||||
if line.split()[0] == "Kd":
|
||||
# RGB diffuse reflectivity
|
||||
kd = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
kd = torch.from_numpy(kd).to(device)
|
||||
material_colors[material_name]["diffuse_color"] = kd
|
||||
if line.split()[0] == "Ka":
|
||||
# RGB ambient reflectivity
|
||||
ka = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ka = torch.from_numpy(ka).to(device)
|
||||
material_colors[material_name]["ambient_color"] = ka
|
||||
if line.split()[0] == "Ks":
|
||||
# RGB specular reflectivity
|
||||
ks = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ks = torch.from_numpy(ks).to(device)
|
||||
material_colors[material_name]["specular_color"] = ks
|
||||
if line.split()[0] == "Ns":
|
||||
# Specular exponent
|
||||
ns = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ns = torch.from_numpy(ns).to(device)
|
||||
material_colors[material_name]["shininess"] = ns
|
||||
|
||||
# Only keep the materials referenced in the obj.
|
||||
for name in material_names:
|
||||
|
@ -72,7 +72,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
|
||||
|
||||
|
||||
def load_obj(
|
||||
f_obj,
|
||||
f,
|
||||
load_textures=True,
|
||||
create_texture_atlas: bool = False,
|
||||
texture_atlas_size: int = 4,
|
||||
@ -211,14 +211,13 @@ def load_obj(
|
||||
None.
|
||||
"""
|
||||
data_dir = "./"
|
||||
if isinstance(f_obj, (str, bytes, os.PathLike)):
|
||||
if isinstance(f, (str, bytes, os.PathLike)):
|
||||
# pyre-fixme[6]: Expected `_PathLike[Variable[typing.AnyStr <: [str,
|
||||
# bytes]]]` for 1st param but got `Union[_PathLike[typing.Any], bytes, str]`.
|
||||
data_dir = os.path.dirname(f_obj)
|
||||
f_obj, new_f = _open_file(f_obj, "r")
|
||||
try:
|
||||
return _load(
|
||||
f_obj,
|
||||
data_dir = os.path.dirname(f)
|
||||
with _open_file(f, "r") as f:
|
||||
return _load_obj(
|
||||
f,
|
||||
data_dir,
|
||||
load_textures=load_textures,
|
||||
create_texture_atlas=create_texture_atlas,
|
||||
@ -226,9 +225,6 @@ def load_obj(
|
||||
texture_wrap=texture_wrap,
|
||||
device=device,
|
||||
)
|
||||
finally:
|
||||
if new_f:
|
||||
f_obj.close()
|
||||
|
||||
|
||||
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
|
||||
@ -338,7 +334,7 @@ def _parse_face(
|
||||
faces_materials_idx.append(material_idx)
|
||||
|
||||
|
||||
def _load(
|
||||
def _load_obj(
|
||||
f_obj,
|
||||
data_dir,
|
||||
load_textures: bool = True,
|
||||
@ -523,12 +519,8 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
|
||||
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
f, new_f = _open_file(f, "w")
|
||||
try:
|
||||
with _open_file(f, "w") as f:
|
||||
return _save(f, verts, faces, decimal_places)
|
||||
finally:
|
||||
if new_f:
|
||||
f.close()
|
||||
|
||||
|
||||
# TODO (nikhilar) Speed up this function.
|
||||
|
@ -603,13 +603,8 @@ def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]:
|
||||
uniformly-sized list, then the value will be a 2D numpy array.
|
||||
If not, it is a list of the relevant property values.
|
||||
"""
|
||||
f, new_f = _open_file(f, "rb")
|
||||
try:
|
||||
with _open_file(f, "rb") as f:
|
||||
header, elements = _load_ply_raw_stream(f)
|
||||
finally:
|
||||
if new_f:
|
||||
f.close()
|
||||
|
||||
return header, elements
|
||||
|
||||
|
||||
@ -800,9 +795,5 @@ def save_ply(
|
||||
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
f, new_f = _open_file(f, "w")
|
||||
try:
|
||||
with _open_file(f, "w") as f:
|
||||
_save_ply(f, verts, faces, verts_normals, decimal_places)
|
||||
finally:
|
||||
if new_f:
|
||||
f.close()
|
||||
|
@ -1,22 +1,23 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import contextlib
|
||||
import pathlib
|
||||
from typing import IO, ContextManager
|
||||
|
||||
import numpy as np
|
||||
from fvcore.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# TODO(plabatut): Replace with a context manager
|
||||
def _open_file(f, mode="r"):
|
||||
new_f = False
|
||||
def _open_file(f, mode="r") -> ContextManager[IO]:
|
||||
if isinstance(f, str):
|
||||
new_f = True
|
||||
f = open(f, mode)
|
||||
return contextlib.closing(f)
|
||||
elif isinstance(f, pathlib.Path):
|
||||
new_f = True
|
||||
f = f.open(mode)
|
||||
return f, new_f
|
||||
return contextlib.closing(f)
|
||||
else:
|
||||
return contextlib.nullcontext(f)
|
||||
|
||||
|
||||
def _read_image(file_name: str, format=None):
|
||||
|
Loading…
x
Reference in New Issue
Block a user