diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 9fe557ba..5208d88b 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -20,6 +20,7 @@ import numpy as np import torch from iopath.common.file_io import PathManager from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file +from pytorch3d.renderer import TexturesVertex from pytorch3d.structures import Meshes, Pointclouds from .pluggable_formats import ( @@ -66,7 +67,7 @@ class _PlyElementType: def __init__(self, name: str, count: int): self.name = name self.count = count - self.properties = [] + self.properties: List[_Property] = [] def add_property( self, name: str, data_type: str, list_size_type: Optional[str] = None @@ -142,7 +143,7 @@ class _PlyHeader: if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]: raise ValueError("Invalid file header.") seen_format = False - self.elements = [] + self.elements: List[_PlyElementType] = [] self.obj_info = [] while True: line = f.readline() @@ -891,8 +892,8 @@ def _get_verts( def _load_ply( - f, *, path_manager: PathManager, return_vertex_colors: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + f, *, path_manager: PathManager +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Load the data from a .ply file. @@ -903,12 +904,11 @@ def _load_ply( ply format, then a text stream is not supported. It is easiest to use a binary stream in all cases. path_manager: PathManager for loading if f is a str. - return_vertex_colors: whether to return vertex colors. Returns: verts: FloatTensor of shape (V, 3). faces: None or LongTensor of vertex indices, shape (F, 3). - vertex_colors: None or FloatTensor of shape (V, 3), only if requested + vertex_colors: None or FloatTensor of shape (V, 3). """ header, elements = _load_ply_raw(f, path_manager=path_manager) @@ -950,16 +950,17 @@ def _load_ply( if faces is not None: _check_faces_indices(faces, max_index=verts.shape[0]) - if return_vertex_colors: - return verts, faces, vertex_colors - return verts, faces, None + return verts, faces, vertex_colors def load_ply( f, *, path_manager: Optional[PathManager] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Load the data from a .ply file. + Load the verts and faces from a .ply file. + Note that the preferred way to load data from such a file + is to use the IO.load_mesh and IO.load_pointcloud functions, + which can read more of the data. Example .ply file format: @@ -1016,8 +1017,8 @@ def _save_ply( *, verts: torch.Tensor, faces: Optional[torch.LongTensor], - verts_normals: torch.Tensor, - verts_colors: torch.Tensor, + verts_normals: Optional[torch.Tensor], + verts_colors: Optional[torch.Tensor], ascii: bool, decimal_places: Optional[int] = None, ) -> None: @@ -1029,16 +1030,16 @@ def _save_ply( verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. verts_normals: FloatTensor of shape (V, 3) giving vertex normals. + verts_colors: FloatTensor of shape (V, 3) giving vertex colors. ascii: (bool) whether to use the ascii ply format. decimal_places: Number of decimal places for saving if ascii=True. """ assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) - if faces is not None: - assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) - assert not len(verts_normals) or ( + assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) + assert verts_normals is None or ( verts_normals.dim() == 2 and verts_normals.size(1) == 3 ) - assert not len(verts_colors) or ( + assert verts_colors is None or ( verts_colors.dim() == 2 and verts_colors.size(1) == 3 ) @@ -1052,11 +1053,11 @@ def _save_ply( f.write(b"property float x\n") f.write(b"property float y\n") f.write(b"property float z\n") - if verts_normals.numel() > 0: + if verts_normals is not None: f.write(b"property float nx\n") f.write(b"property float ny\n") f.write(b"property float nz\n") - if verts_colors.numel() > 0: + if verts_colors is not None: f.write(b"property float red\n") f.write(b"property float green\n") f.write(b"property float blue\n") @@ -1069,7 +1070,13 @@ def _save_ply( warnings.warn("Empty 'verts' provided") return - vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy() + verts_tensors = [verts] + if verts_normals is not None: + verts_tensors.append(verts_normals) + if verts_colors is not None: + verts_tensors.append(verts_colors) + + vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy() if ascii: if decimal_places is None: float_str = "%f" @@ -1085,7 +1092,7 @@ def _save_ply( vert_data.tofile(f) if faces is not None: - faces_array = faces.detach().numpy() + faces_array = faces.detach().cpu().numpy() _check_faces_indices(faces, max_index=verts.shape[0]) @@ -1125,12 +1132,6 @@ def save_ply( """ - verts_normals = ( - torch.tensor([], dtype=torch.float32, device=verts.device) - if verts_normals is None - else verts_normals - ) - if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." raise ValueError(message) @@ -1143,16 +1144,18 @@ def save_ply( message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." raise ValueError(message) - if len(verts_normals) and not ( - verts_normals.dim() == 2 - and verts_normals.size(1) == 3 - and verts_normals.size(0) == verts.size(0) + if ( + verts_normals is not None + and len(verts_normals) + and not ( + verts_normals.dim() == 2 + and verts_normals.size(1) == 3 + and verts_normals.size(0) == verts.size(0) + ) ): message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." raise ValueError(message) - verts_colors = torch.FloatTensor([]) - if path_manager is None: path_manager = PathManager() with _open_file(f, path_manager, "wb") as f: @@ -1161,7 +1164,7 @@ def save_ply( verts=verts, faces=faces, verts_normals=verts_normals, - verts_colors=verts_colors, + verts_colors=None, ascii=ascii, decimal_places=decimal_places, ) @@ -1182,8 +1185,19 @@ class MeshPlyFormat(MeshFormatInterpreter): if not endswith(path, self.known_suffixes): return None - verts, faces = load_ply(f=path, path_manager=path_manager) - mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)]) + verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager) + if faces is None: + faces = torch.zeros(0, 3, dtype=torch.int64) + + textures = None + if include_textures and verts_colors is not None: + textures = TexturesVertex([verts_colors.to(device)]) + + mesh = Meshes( + verts=[verts.to(device)], + faces=[faces.to(device)], + textures=textures, + ) return mesh def save( @@ -1201,14 +1215,30 @@ class MeshPlyFormat(MeshFormatInterpreter): # TODO: normals are not saved. We only want to save them if they already exist. verts = data.verts_list()[0] faces = data.faces_list()[0] - save_ply( - f=path, - verts=verts, - faces=faces, - ascii=binary is False, - decimal_places=decimal_places, - path_manager=path_manager, - ) + + if isinstance(data.textures, TexturesVertex): + mesh_verts_colors = data.textures.verts_features_list()[0] + n_colors = mesh_verts_colors.shape[1] + if n_colors == 3: + verts_colors = mesh_verts_colors + else: + warnings.warn( + f"Texture will not be saved as it has {n_colors} colors, not 3." + ) + verts_colors = None + else: + verts_colors = None + + with _open_file(path, path_manager, "wb") as f: + _save_ply( + f=f, + verts=verts, + faces=faces, + verts_colors=verts_colors, + verts_normals=None, + ascii=binary is False, + decimal_places=decimal_places, + ) return True @@ -1226,14 +1256,12 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): if not endswith(path, self.known_suffixes): return None - verts, faces, features = _load_ply( - f=path, path_manager=path_manager, return_vertex_colors=True - ) + verts, faces, features = _load_ply(f=path, path_manager=path_manager) verts = verts.to(device) - if features is None: - pointcloud = Pointclouds(points=[verts]) - else: - pointcloud = Pointclouds(points=[verts], features=[features.to(device)]) + if features is not None: + features = [features.to(device)] + + pointcloud = Pointclouds(points=[verts], features=features) return pointcloud def save( @@ -1249,13 +1277,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): return False points = data.points_list()[0] - features = data.features_list()[0] + features = data.features_packed() + with _open_file(path, path_manager, "wb") as f: _save_ply( f=f, verts=points, verts_colors=features, - verts_normals=torch.FloatTensor([]), + verts_normals=None, faces=None, ascii=binary is False, decimal_places=decimal_places, diff --git a/tests/test_io_ply.py b/tests/test_io_ply.py index 90efb1ff..bfbb0e44 100644 --- a/tests/test_io_ply.py +++ b/tests/test_io_ply.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import itertools import struct import unittest from io import BytesIO, StringIO @@ -12,7 +13,8 @@ from common_testing import TestCaseMixin from iopath.common.file_io import PathManager from pytorch3d.io import IO from pytorch3d.io.ply_io import load_ply, save_ply -from pytorch3d.structures import Pointclouds +from pytorch3d.renderer.mesh import TexturesVertex +from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.utils import torus @@ -189,6 +191,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): ): io.load_mesh(f3.name) + def test_save_too_many_colors(self): + verts = torch.tensor( + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 + ) + faces = torch.tensor([[0, 1, 2], [0, 2, 3]]) + vert_colors = torch.rand((4, 7)) + texture_with_seven_colors = TexturesVertex(verts_features=[vert_colors]) + + mesh = Meshes( + verts=[verts], + faces=[faces], + textures=texture_with_seven_colors, + ) + + io = IO() + msg = "Texture will not be saved as it has 7 colors, not 3." + with NamedTemporaryFile(mode="w", suffix=".ply") as f: + with self.assertWarnsRegex(UserWarning, msg): + io.save_mesh(mesh.cuda(), f.name) + + def test_save_load_meshes(self): + verts = torch.tensor( + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 + ) + faces = torch.tensor([[0, 1, 2], [0, 2, 3]]) + vert_colors = torch.rand_like(verts) + texture = TexturesVertex(verts_features=[vert_colors]) + + for do_textures in itertools.product([True, False]): + mesh = Meshes( + verts=[verts], + faces=[faces], + textures=texture if do_textures else None, + ) + device = torch.device("cuda:0") + + io = IO() + with NamedTemporaryFile(mode="w", suffix=".ply") as f: + io.save_mesh(mesh.cuda(), f.name) + f.flush() + mesh2 = io.load_mesh(f.name, device=device) + self.assertEqual(mesh2.device, device) + mesh2 = mesh2.cpu() + self.assertClose(mesh2.verts_padded(), mesh.verts_padded()) + self.assertClose(mesh2.faces_padded(), mesh.faces_padded()) + if do_textures: + self.assertIsInstance(mesh2.textures, TexturesVertex) + self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors) + else: + self.assertIsNone(mesh2.textures) + def test_save_ply_invalid_shapes(self): # Invalid vertices shape with self.assertRaises(ValueError) as error: