No side effect with invalid inputs to save_obj / save_ply

Summary: Do not create output files with invalid inputs to `save_{obj,ply}()`.

Reviewed By: bottler

Differential Revision: D20720282

fbshipit-source-id: 3b611a10da6f6eecacab2a1900bf16f89e2000aa
This commit is contained in:
Patrick Labatut 2020-04-01 11:41:31 -07:00 committed by Facebook GitHub Bot
parent 83feed56a0
commit 745aaf3915
2 changed files with 23 additions and 21 deletions

View File

@ -526,6 +526,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
"""
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
new_f = False
if isinstance(f, str):
new_f = True
@ -541,21 +549,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None):
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
return
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
raise ValueError(
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
raise ValueError(
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
verts, faces = verts.cpu(), faces.cpu()
lines = ""

View File

@ -700,7 +700,7 @@ def load_ply(f):
return verts, faces
def _save_ply(f, verts, faces, decimal_places: Optional[int]):
def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
"""
Internal implementation for saving a mesh to a .ply file.
@ -710,15 +710,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]):
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
"""
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
raise ValueError(
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
raise ValueError(
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
print("ply\nformat ascii 1.0", file=f)
print(f"element vertex {verts.shape[0]}", file=f)
@ -760,6 +753,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
"""
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
new_f = False
if isinstance(f, str):
new_f = True