diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 54abb7a5..cbee7afb 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -8,6 +8,7 @@ import struct import sys import warnings from collections import namedtuple +from io import BytesIO from typing import Optional, Tuple import numpy as np @@ -386,11 +387,18 @@ def _read_ply_fixed_size_element_binary( np_type = ply_type.np_type type_size = ply_type.size needed_length = definition.count * len(definition.properties) - needed_bytes = needed_length * type_size - bytes_data = f.read(needed_bytes) - if len(bytes_data) != needed_bytes: - raise ValueError("Not enough data for %s." % definition.name) - data = np.frombuffer(bytes_data, dtype=np_type) + if isinstance(f, BytesIO): + # np.fromfile is faster but won't work on a BytesIO + needed_bytes = needed_length * type_size + bytes_data = bytearray(needed_bytes) + n_bytes_read = f.readinto(bytes_data) + if n_bytes_read != needed_bytes: + raise ValueError("Not enough data for %s." % definition.name) + data = np.frombuffer(bytes_data, dtype=np_type) + else: + data = np.fromfile(f, dtype=np_type, count=needed_length) + if data.shape[0] != needed_length: + raise ValueError("Not enough data for %s." % definition.name) if (sys.byteorder == "big") != big_endian: data = data.byteswap() @@ -447,6 +455,8 @@ def _try_read_ply_constant_list_binary( If every element has the same size, 2D numpy array corresponding to the data. The rows are the different values. Otherwise None. """ + if definition.count == 0: + return [] property = definition.properties[0] endian_str = ">" if big_endian else "<" length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char @@ -689,6 +699,7 @@ def _save_ply( verts: torch.Tensor, faces: torch.LongTensor, verts_normals: torch.Tensor, + ascii: bool, decimal_places: Optional[int] = None, ) -> None: """ @@ -699,7 +710,8 @@ def _save_ply( verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shsape (F, 3) giving faces. verts_normals: FloatTensor of shape (V, 3) giving vertex normals. - decimal_places: Number of decimal places for saving. + ascii: (bool) whether to use the ascii ply format. + decimal_places: Number of decimal places for saving if ascii=True. """ assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) @@ -707,37 +719,58 @@ def _save_ply( verts_normals.dim() == 2 and verts_normals.size(1) == 3 ) - print("ply\nformat ascii 1.0", file=f) - print(f"element vertex {verts.shape[0]}", file=f) - print("property float x", file=f) - print("property float y", file=f) - print("property float z", file=f) + if ascii: + f.write(b"ply\nformat ascii 1.0\n") + elif sys.byteorder == "big": + f.write(b"ply\nformat binary_big_endian 1.0\n") + else: + f.write(b"ply\nformat binary_little_endian 1.0\n") + f.write(f"element vertex {verts.shape[0]}\n".encode("ascii")) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") if verts_normals.numel() > 0: - print("property float nx", file=f) - print("property float ny", file=f) - print("property float nz", file=f) - print(f"element face {faces.shape[0]}", file=f) - print("property list uchar int vertex_index", file=f) - print("end_header", file=f) + f.write(b"property float nx\n") + f.write(b"property float ny\n") + f.write(b"property float nz\n") + f.write(f"element face {faces.shape[0]}\n".encode("ascii")) + f.write(b"property list uchar int vertex_index\n") + f.write(b"end_header\n") if not (len(verts) or len(faces)): warnings.warn("Empty 'verts' and 'faces' arguments provided") return - if decimal_places is None: - float_str = "%f" + vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy() + if ascii: + if decimal_places is None: + float_str = "%f" + else: + float_str = "%" + ".%df" % decimal_places + np.savetxt(f, vert_data, float_str) else: - float_str = "%" + ".%df" % decimal_places - - vert_data = torch.cat((verts, verts_normals), dim=1) - np.savetxt(f, vert_data.detach().numpy(), float_str) + assert vert_data.dtype == np.float32 + if isinstance(f, BytesIO): + # tofile only works with real files, but is faster than this. + f.write(vert_data.tobytes()) + else: + vert_data.tofile(f) faces_array = faces.detach().numpy() _check_faces_indices(faces, max_index=verts.shape[0]) if len(faces_array): - np.savetxt(f, faces_array, "3 %d %d %d") + if ascii: + np.savetxt(f, faces_array, "3 %d %d %d") + else: + # rows are 13 bytes: a one-byte 3 followed by three four-byte face indices. + faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8) + faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8) + if isinstance(f, BytesIO): + f.write(faces_uints.tobytes()) + else: + faces_uints.tofile(f) def save_ply( @@ -745,6 +778,7 @@ def save_ply( verts: torch.Tensor, faces: Optional[torch.LongTensor] = None, verts_normals: Optional[torch.Tensor] = None, + ascii: bool = False, decimal_places: Optional[int] = None, ) -> None: """ @@ -755,7 +789,8 @@ def save_ply( verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. verts_normals: FloatTensor of shape (V, 3) giving vertex normals. - decimal_places: Number of decimal places for saving. + ascii: (bool) whether to use the ascii ply format. + decimal_places: Number of decimal places for saving if ascii=True. """ verts_normals = ( @@ -781,5 +816,5 @@ def save_ply( message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." raise ValueError(message) - with _open_file(f, "w") as f: - _save_ply(f, verts, faces, verts_normals, decimal_places) + with _open_file(f, "wb") as f: + _save_ply(f, verts, faces, verts_normals, ascii, decimal_places) diff --git a/tests/test_ply_io.py b/tests/test_ply_io.py index b270803d..408311bf 100644 --- a/tests/test_ply_io.py +++ b/tests/test_ply_io.py @@ -3,6 +3,7 @@ import struct import unittest from io import BytesIO, StringIO +from tempfile import TemporaryFile import torch from common_testing import TestCaseMixin @@ -144,7 +145,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError) as error: verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) faces = torch.LongTensor([[0, 1, 2]]) - save_ply(StringIO(), verts, faces) + save_ply(BytesIO(), verts, faces) expected_message = ( "Argument 'verts' should either be empty or of shape (num_verts, 3)." ) @@ -154,7 +155,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError) as error: verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) - save_ply(StringIO(), verts, faces) + save_ply(BytesIO(), verts, faces) expected_message = ( "Argument 'faces' should either be empty or of shape (num_faces, 3)." ) @@ -165,14 +166,14 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) faces = torch.LongTensor([[0, 1, 2]]) with self.assertWarnsRegex(UserWarning, message_regex): - save_ply(StringIO(), verts, faces) + save_ply(BytesIO(), verts, faces) faces = torch.LongTensor([[-1, 0, 1]]) with self.assertWarnsRegex(UserWarning, message_regex): - save_ply(StringIO(), verts, faces) + save_ply(BytesIO(), verts, faces) def _test_save_load(self, verts, faces): - f = StringIO() + f = BytesIO() save_ply(f, verts, faces) f.seek(0) # raise Exception(f.getvalue()) @@ -193,7 +194,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): normals = torch.tensor( [[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32 ) - file = StringIO() + file = BytesIO() save_ply(file, verts=verts, faces=faces, verts_normals=normals) file.close() @@ -237,15 +238,31 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): def test_simple_save(self): verts = torch.tensor( - [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32 ) faces = torch.tensor([[0, 1, 2], [0, 3, 4]]) - file = StringIO() - save_ply(file, verts=verts, faces=faces) - file.seek(0) - verts2, faces2 = load_ply(file) - self.assertClose(verts, verts2) - self.assertClose(faces, faces2) + for filetype in BytesIO, TemporaryFile: + lengths = {} + for ascii in [True, False]: + file = filetype() + save_ply(file, verts=verts, faces=faces, ascii=ascii) + lengths[ascii] = file.tell() + + file.seek(0) + verts2, faces2 = load_ply(file) + self.assertClose(verts, verts2) + self.assertClose(faces, faces2) + + file.seek(0) + if ascii: + file.read().decode("ascii") + else: + with self.assertRaises(UnicodeDecodeError): + file.read().decode("ascii") + + if filetype is TemporaryFile: + file.close() + self.assertLess(lengths[False], lengths[True], "ascii should be longer") def test_load_simple_binary(self): for big_endian in [True, False]: @@ -488,15 +505,21 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): @staticmethod def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): - return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places) + return lambda: save_ply( + BytesIO(), + verts=verts, + faces=faces, + ascii=True, + decimal_places=decimal_places, + ) @staticmethod def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): - f = StringIO() - save_ply(f, verts, faces, decimal_places) + f = BytesIO() + save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places) s = f.getvalue() # Recreate stream so it's unaffected by how it was created. - return lambda: load_ply(StringIO(s)) + return lambda: load_ply(BytesIO(s)) @staticmethod def bm_save_simple_ply_with_init(V: int, F: int):