mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
No side effect with invalid inputs to save_obj / save_ply
Summary: Do not create output files with invalid inputs to `save_{obj,ply}()`. Reviewed By: bottler Differential Revision: D20720282 fbshipit-source-id: 3b611a10da6f6eecacab2a1900bf16f89e2000aa
This commit is contained in:
parent
83feed56a0
commit
745aaf3915
@ -526,6 +526,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
"""
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
new_f = False
|
||||
if isinstance(f, str):
|
||||
new_f = True
|
||||
@ -541,21 +549,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
|
||||
|
||||
|
||||
# TODO (nikhilar) Speed up this function.
|
||||
def _save(f, verts, faces, decimal_places: Optional[int] = None):
|
||||
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
|
||||
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||
|
||||
if not (len(verts) or len(faces)):
|
||||
warnings.warn("Empty 'verts' and 'faces' arguments provided")
|
||||
return
|
||||
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
raise ValueError(
|
||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
)
|
||||
|
||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||
raise ValueError(
|
||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
|
||||
verts, faces = verts.cpu(), faces.cpu()
|
||||
|
||||
lines = ""
|
||||
|
@ -700,7 +700,7 @@ def load_ply(f):
|
||||
return verts, faces
|
||||
|
||||
|
||||
def _save_ply(f, verts, faces, decimal_places: Optional[int]):
|
||||
def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
|
||||
"""
|
||||
Internal implementation for saving a mesh to a .ply file.
|
||||
|
||||
@ -710,15 +710,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]):
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
"""
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
raise ValueError(
|
||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
)
|
||||
|
||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||
raise ValueError(
|
||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||
|
||||
print("ply\nformat ascii 1.0", file=f)
|
||||
print(f"element vertex {verts.shape[0]}", file=f)
|
||||
@ -760,6 +753,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
"""
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
new_f = False
|
||||
if isinstance(f, str):
|
||||
new_f = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user