save_ply binary

Summary:
Make save_ply write the binary ply format by default instead of ascii. Passing ascii=True restores the previous behavior. save_ply's API also accepts a stream, although this is undocumented; that stream must now be a binary stream, not a text stream.

Avoid warnings about making tensors from immutable numpy arrays.

Possible performance improvement when reading binary files.

Fix reading zero-length binary lists.

Reviewed By: nikhilaravi

Differential Revision: D22333118

fbshipit-source-id: b423dfd3da46e047bead200255f47a7707306811
This commit is contained in:
Jeremy Reizenstein 2020-09-21 05:14:59 -07:00 committed by Facebook GitHub Bot
parent ebe2693b11
commit 197f1d6217
2 changed files with 102 additions and 44 deletions

View File

@ -8,6 +8,7 @@ import struct
import sys import sys
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from io import BytesIO
from typing import Optional, Tuple from typing import Optional, Tuple
import numpy as np import numpy as np
@ -386,11 +387,18 @@ def _read_ply_fixed_size_element_binary(
np_type = ply_type.np_type np_type = ply_type.np_type
type_size = ply_type.size type_size = ply_type.size
needed_length = definition.count * len(definition.properties) needed_length = definition.count * len(definition.properties)
needed_bytes = needed_length * type_size if isinstance(f, BytesIO):
bytes_data = f.read(needed_bytes) # np.fromfile is faster but won't work on a BytesIO
if len(bytes_data) != needed_bytes: needed_bytes = needed_length * type_size
raise ValueError("Not enough data for %s." % definition.name) bytes_data = bytearray(needed_bytes)
data = np.frombuffer(bytes_data, dtype=np_type) n_bytes_read = f.readinto(bytes_data)
if n_bytes_read != needed_bytes:
raise ValueError("Not enough data for %s." % definition.name)
data = np.frombuffer(bytes_data, dtype=np_type)
else:
data = np.fromfile(f, dtype=np_type, count=needed_length)
if data.shape[0] != needed_length:
raise ValueError("Not enough data for %s." % definition.name)
if (sys.byteorder == "big") != big_endian: if (sys.byteorder == "big") != big_endian:
data = data.byteswap() data = data.byteswap()
@ -447,6 +455,8 @@ def _try_read_ply_constant_list_binary(
If every element has the same size, 2D numpy array corresponding to the If every element has the same size, 2D numpy array corresponding to the
data. The rows are the different values. Otherwise None. data. The rows are the different values. Otherwise None.
""" """
if definition.count == 0:
return []
property = definition.properties[0] property = definition.properties[0]
endian_str = ">" if big_endian else "<" endian_str = ">" if big_endian else "<"
length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char
@ -689,6 +699,7 @@ def _save_ply(
verts: torch.Tensor, verts: torch.Tensor,
faces: torch.LongTensor, faces: torch.LongTensor,
verts_normals: torch.Tensor, verts_normals: torch.Tensor,
ascii: bool,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
) -> None: ) -> None:
""" """
@ -699,7 +710,8 @@ def _save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates. verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shsape (F, 3) giving faces. faces: LongTensor of shsape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals. verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving. ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
""" """
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
@ -707,37 +719,58 @@ def _save_ply(
verts_normals.dim() == 2 and verts_normals.size(1) == 3 verts_normals.dim() == 2 and verts_normals.size(1) == 3
) )
print("ply\nformat ascii 1.0", file=f) if ascii:
print(f"element vertex {verts.shape[0]}", file=f) f.write(b"ply\nformat ascii 1.0\n")
print("property float x", file=f) elif sys.byteorder == "big":
print("property float y", file=f) f.write(b"ply\nformat binary_big_endian 1.0\n")
print("property float z", file=f) else:
f.write(b"ply\nformat binary_little_endian 1.0\n")
f.write(f"element vertex {verts.shape[0]}\n".encode("ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if verts_normals.numel() > 0: if verts_normals.numel() > 0:
print("property float nx", file=f) f.write(b"property float nx\n")
print("property float ny", file=f) f.write(b"property float ny\n")
print("property float nz", file=f) f.write(b"property float nz\n")
print(f"element face {faces.shape[0]}", file=f) f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
print("property list uchar int vertex_index", file=f) f.write(b"property list uchar int vertex_index\n")
print("end_header", file=f) f.write(b"end_header\n")
if not (len(verts) or len(faces)): if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided") warnings.warn("Empty 'verts' and 'faces' arguments provided")
return return
if decimal_places is None: vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
float_str = "%f" if ascii:
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
np.savetxt(f, vert_data, float_str)
else: else:
float_str = "%" + ".%df" % decimal_places assert vert_data.dtype == np.float32
if isinstance(f, BytesIO):
vert_data = torch.cat((verts, verts_normals), dim=1) # tofile only works with real files, but is faster than this.
np.savetxt(f, vert_data.detach().numpy(), float_str) f.write(vert_data.tobytes())
else:
vert_data.tofile(f)
faces_array = faces.detach().numpy() faces_array = faces.detach().numpy()
_check_faces_indices(faces, max_index=verts.shape[0]) _check_faces_indices(faces, max_index=verts.shape[0])
if len(faces_array): if len(faces_array):
np.savetxt(f, faces_array, "3 %d %d %d") if ascii:
np.savetxt(f, faces_array, "3 %d %d %d")
else:
# rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
if isinstance(f, BytesIO):
f.write(faces_uints.tobytes())
else:
faces_uints.tofile(f)
def save_ply( def save_ply(
@ -745,6 +778,7 @@ def save_ply(
verts: torch.Tensor, verts: torch.Tensor,
faces: Optional[torch.LongTensor] = None, faces: Optional[torch.LongTensor] = None,
verts_normals: Optional[torch.Tensor] = None, verts_normals: Optional[torch.Tensor] = None,
ascii: bool = False,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
) -> None: ) -> None:
""" """
@ -755,7 +789,8 @@ def save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates. verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals. verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving. ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
""" """
verts_normals = ( verts_normals = (
@ -781,5 +816,5 @@ def save_ply(
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
with _open_file(f, "w") as f: with _open_file(f, "wb") as f:
_save_ply(f, verts, faces, verts_normals, decimal_places) _save_ply(f, verts, faces, verts_normals, ascii, decimal_places)

View File

@ -3,6 +3,7 @@
import struct import struct
import unittest import unittest
from io import BytesIO, StringIO from io import BytesIO, StringIO
from tempfile import TemporaryFile
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
@ -144,7 +145,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
save_ply(StringIO(), verts, faces) save_ply(BytesIO(), verts, faces)
expected_message = ( expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)." "Argument 'verts' should either be empty or of shape (num_verts, 3)."
) )
@ -154,7 +155,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_ply(StringIO(), verts, faces) save_ply(BytesIO(), verts, faces)
expected_message = ( expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)." "Argument 'faces' should either be empty or of shape (num_faces, 3)."
) )
@ -165,14 +166,14 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]]) faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex): with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces) save_ply(BytesIO(), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]]) faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex): with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces) save_ply(BytesIO(), verts, faces)
def _test_save_load(self, verts, faces): def _test_save_load(self, verts, faces):
f = StringIO() f = BytesIO()
save_ply(f, verts, faces) save_ply(f, verts, faces)
f.seek(0) f.seek(0)
# raise Exception(f.getvalue()) # raise Exception(f.getvalue())
@ -193,7 +194,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
normals = torch.tensor( normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32 [[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
) )
file = StringIO() file = BytesIO()
save_ply(file, verts=verts, faces=faces, verts_normals=normals) save_ply(file, verts=verts, faces=faces, verts_normals=normals)
file.close() file.close()
@ -237,15 +238,31 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
def test_simple_save(self): def test_simple_save(self):
verts = torch.tensor( verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32
) )
faces = torch.tensor([[0, 1, 2], [0, 3, 4]]) faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
file = StringIO() for filetype in BytesIO, TemporaryFile:
save_ply(file, verts=verts, faces=faces) lengths = {}
file.seek(0) for ascii in [True, False]:
verts2, faces2 = load_ply(file) file = filetype()
self.assertClose(verts, verts2) save_ply(file, verts=verts, faces=faces, ascii=ascii)
self.assertClose(faces, faces2) lengths[ascii] = file.tell()
file.seek(0)
verts2, faces2 = load_ply(file)
self.assertClose(verts, verts2)
self.assertClose(faces, faces2)
file.seek(0)
if ascii:
file.read().decode("ascii")
else:
with self.assertRaises(UnicodeDecodeError):
file.read().decode("ascii")
if filetype is TemporaryFile:
file.close()
self.assertLess(lengths[False], lengths[True], "ascii should be longer")
def test_load_simple_binary(self): def test_load_simple_binary(self):
for big_endian in [True, False]: for big_endian in [True, False]:
@ -488,15 +505,21 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places) return lambda: save_ply(
BytesIO(),
verts=verts,
faces=faces,
ascii=True,
decimal_places=decimal_places,
)
@staticmethod @staticmethod
def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
f = StringIO() f = BytesIO()
save_ply(f, verts, faces, decimal_places) save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places)
s = f.getvalue() s = f.getvalue()
# Recreate stream so it's unaffected by how it was created. # Recreate stream so it's unaffected by how it was created.
return lambda: load_ply(StringIO(s)) return lambda: load_ply(BytesIO(s))
@staticmethod @staticmethod
def bm_save_simple_ply_with_init(V: int, F: int): def bm_save_simple_ply_with_init(V: int, F: int):