diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index feefc746..cdf401b3 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -526,6 +526,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. """ + if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): + message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + + if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): + message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) + new_f = False if isinstance(f, str): new_f = True @@ -541,21 +549,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): # TODO (nikhilar) Speed up this function. -def _save(f, verts, faces, decimal_places: Optional[int] = None): +def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None: + assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) + assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) + if not (len(verts) or len(faces)): warnings.warn("Empty 'verts' and 'faces' arguments provided") return - if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): - raise ValueError( - "Argument 'verts' should either be empty or of shape (num_verts, 3)." - ) - - if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): - raise ValueError( - "Argument 'faces' should either be empty or of shape (num_faces, 3)." - ) - verts, faces = verts.cpu(), faces.cpu() lines = "" diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 8b13b051..8c4182dd 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -700,7 +700,7 @@ def load_ply(f): return verts, faces -def _save_ply(f, verts, faces, decimal_places: Optional[int]): +def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None: """ Internal implementation for saving a mesh to a .ply file. @@ -710,15 +710,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]): faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. """ - if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): - raise ValueError( - "Argument 'verts' should either be empty or of shape (num_verts, 3)." - ) - - if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): - raise ValueError( - "Argument 'faces' should either be empty or of shape (num_faces, 3)." - ) + assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) + assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) print("ply\nformat ascii 1.0", file=f) print(f"element vertex {verts.shape[0]}", file=f) @@ -760,6 +753,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None): faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. """ + if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): + message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + + if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): + message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." + raise ValueError(message) + new_f = False if isinstance(f, str): new_f = True