mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
No side effect with invalid inputs to save_obj / save_ply
Summary: Do not create output files with invalid inputs to `save_{obj,ply}()`. Reviewed By: bottler Differential Revision: D20720282 fbshipit-source-id: 3b611a10da6f6eecacab2a1900bf16f89e2000aa
This commit is contained in:
parent
83feed56a0
commit
745aaf3915
@ -526,6 +526,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
|
|||||||
faces: LongTensor of shape (F, 3) giving faces.
|
faces: LongTensor of shape (F, 3) giving faces.
|
||||||
decimal_places: Number of decimal places for saving.
|
decimal_places: Number of decimal places for saving.
|
||||||
"""
|
"""
|
||||||
|
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||||
|
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
|
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||||
|
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
new_f = False
|
new_f = False
|
||||||
if isinstance(f, str):
|
if isinstance(f, str):
|
||||||
new_f = True
|
new_f = True
|
||||||
@ -541,21 +549,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
|
|||||||
|
|
||||||
|
|
||||||
# TODO (nikhilar) Speed up this function.
|
# TODO (nikhilar) Speed up this function.
|
||||||
def _save(f, verts, faces, decimal_places: Optional[int] = None):
|
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
|
||||||
|
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||||
|
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||||
|
|
||||||
if not (len(verts) or len(faces)):
|
if not (len(verts) or len(faces)):
|
||||||
warnings.warn("Empty 'verts' and 'faces' arguments provided")
|
warnings.warn("Empty 'verts' and 'faces' arguments provided")
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
|
||||||
raise ValueError(
|
|
||||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
|
||||||
raise ValueError(
|
|
||||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
|
||||||
)
|
|
||||||
|
|
||||||
verts, faces = verts.cpu(), faces.cpu()
|
verts, faces = verts.cpu(), faces.cpu()
|
||||||
|
|
||||||
lines = ""
|
lines = ""
|
||||||
|
@ -700,7 +700,7 @@ def load_ply(f):
|
|||||||
return verts, faces
|
return verts, faces
|
||||||
|
|
||||||
|
|
||||||
def _save_ply(f, verts, faces, decimal_places: Optional[int]):
|
def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
|
||||||
"""
|
"""
|
||||||
Internal implementation for saving a mesh to a .ply file.
|
Internal implementation for saving a mesh to a .ply file.
|
||||||
|
|
||||||
@ -710,15 +710,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]):
|
|||||||
faces: LongTensor of shape (F, 3) giving faces.
|
faces: LongTensor of shape (F, 3) giving faces.
|
||||||
decimal_places: Number of decimal places for saving.
|
decimal_places: Number of decimal places for saving.
|
||||||
"""
|
"""
|
||||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||||
raise ValueError(
|
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
|
||||||
raise ValueError(
|
|
||||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
|
||||||
)
|
|
||||||
|
|
||||||
print("ply\nformat ascii 1.0", file=f)
|
print("ply\nformat ascii 1.0", file=f)
|
||||||
print(f"element vertex {verts.shape[0]}", file=f)
|
print(f"element vertex {verts.shape[0]}", file=f)
|
||||||
@ -760,6 +753,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
|
|||||||
faces: LongTensor of shape (F, 3) giving faces.
|
faces: LongTensor of shape (F, 3) giving faces.
|
||||||
decimal_places: Number of decimal places for saving.
|
decimal_places: Number of decimal places for saving.
|
||||||
"""
|
"""
|
||||||
|
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||||
|
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
|
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||||
|
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
new_f = False
|
new_f = False
|
||||||
if isinstance(f, str):
|
if isinstance(f, str):
|
||||||
new_f = True
|
new_f = True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user