diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 128e6f55..29c846e7 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -1238,10 +1238,14 @@ class MeshPlyFormat(MeshFormatInterpreter): if not endswith(path, self.known_suffixes): return False - # TODO: normals are not saved. We only want to save them if they already exist. verts = data.verts_list()[0] faces = data.faces_list()[0] + if data.has_verts_normals(): + verts_normals = data.verts_normals_list()[0] + else: + verts_normals = None + if isinstance(data.textures, TexturesVertex): mesh_verts_colors = data.textures.verts_features_list()[0] n_colors = mesh_verts_colors.shape[1] @@ -1261,7 +1265,7 @@ class MeshPlyFormat(MeshFormatInterpreter): verts=verts, faces=faces, verts_colors=verts_colors, - verts_normals=None, + verts_normals=verts_normals, ascii=binary is False, decimal_places=decimal_places, ) @@ -1304,13 +1308,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): points = data.points_list()[0] features = data.features_packed() + normals = data.normals_packed() with _open_file(path, path_manager, "wb") as f: _save_ply( f=f, verts=points, verts_colors=features, - verts_normals=None, + verts_normals=normals, faces=None, ascii=binary is False, decimal_places=decimal_places,