save_ply binary

Summary:
Make save_ply write the binary PLY format by default instead of ascii. A new `ascii` option makes the previous behavior available. save_ply's API also accepts a stream, although this is undocumented; that stream must now be a binary stream, not a text stream.

Avoid warnings about making tensors from immutable numpy arrays.

Possible performance improvement when reading binary files.

Fix reading zero-length binary lists.

Reviewed By: nikhilaravi

Differential Revision: D22333118

fbshipit-source-id: b423dfd3da46e047bead200255f47a7707306811
This commit is contained in:
Jeremy Reizenstein 2020-09-21 05:14:59 -07:00 committed by Facebook GitHub Bot
parent ebe2693b11
commit 197f1d6217
2 changed files with 102 additions and 44 deletions

View File

@ -8,6 +8,7 @@ import struct
import sys
import warnings
from collections import namedtuple
from io import BytesIO
from typing import Optional, Tuple
import numpy as np
@ -386,11 +387,18 @@ def _read_ply_fixed_size_element_binary(
np_type = ply_type.np_type
type_size = ply_type.size
needed_length = definition.count * len(definition.properties)
needed_bytes = needed_length * type_size
bytes_data = f.read(needed_bytes)
if len(bytes_data) != needed_bytes:
raise ValueError("Not enough data for %s." % definition.name)
data = np.frombuffer(bytes_data, dtype=np_type)
if isinstance(f, BytesIO):
# np.fromfile is faster but won't work on a BytesIO
needed_bytes = needed_length * type_size
bytes_data = bytearray(needed_bytes)
n_bytes_read = f.readinto(bytes_data)
if n_bytes_read != needed_bytes:
raise ValueError("Not enough data for %s." % definition.name)
data = np.frombuffer(bytes_data, dtype=np_type)
else:
data = np.fromfile(f, dtype=np_type, count=needed_length)
if data.shape[0] != needed_length:
raise ValueError("Not enough data for %s." % definition.name)
if (sys.byteorder == "big") != big_endian:
data = data.byteswap()
@ -447,6 +455,8 @@ def _try_read_ply_constant_list_binary(
If every element has the same size, 2D numpy array corresponding to the
data. The rows are the different values. Otherwise None.
"""
if definition.count == 0:
return []
property = definition.properties[0]
endian_str = ">" if big_endian else "<"
length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char
@ -689,6 +699,7 @@ def _save_ply(
verts: torch.Tensor,
faces: torch.LongTensor,
verts_normals: torch.Tensor,
ascii: bool,
decimal_places: Optional[int] = None,
) -> None:
"""
@ -699,7 +710,8 @@ def _save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shsape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving.
ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
"""
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
@ -707,37 +719,58 @@ def _save_ply(
verts_normals.dim() == 2 and verts_normals.size(1) == 3
)
print("ply\nformat ascii 1.0", file=f)
print(f"element vertex {verts.shape[0]}", file=f)
print("property float x", file=f)
print("property float y", file=f)
print("property float z", file=f)
if ascii:
f.write(b"ply\nformat ascii 1.0\n")
elif sys.byteorder == "big":
f.write(b"ply\nformat binary_big_endian 1.0\n")
else:
f.write(b"ply\nformat binary_little_endian 1.0\n")
f.write(f"element vertex {verts.shape[0]}\n".encode("ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if verts_normals.numel() > 0:
print("property float nx", file=f)
print("property float ny", file=f)
print("property float nz", file=f)
print(f"element face {faces.shape[0]}", file=f)
print("property list uchar int vertex_index", file=f)
print("end_header", file=f)
f.write(b"property float nx\n")
f.write(b"property float ny\n")
f.write(b"property float nz\n")
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
return
if decimal_places is None:
float_str = "%f"
vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
if ascii:
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
np.savetxt(f, vert_data, float_str)
else:
float_str = "%" + ".%df" % decimal_places
vert_data = torch.cat((verts, verts_normals), dim=1)
np.savetxt(f, vert_data.detach().numpy(), float_str)
assert vert_data.dtype == np.float32
if isinstance(f, BytesIO):
# tofile only works with real files, but is faster than this.
f.write(vert_data.tobytes())
else:
vert_data.tofile(f)
faces_array = faces.detach().numpy()
_check_faces_indices(faces, max_index=verts.shape[0])
if len(faces_array):
np.savetxt(f, faces_array, "3 %d %d %d")
if ascii:
np.savetxt(f, faces_array, "3 %d %d %d")
else:
# rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
if isinstance(f, BytesIO):
f.write(faces_uints.tobytes())
else:
faces_uints.tofile(f)
def save_ply(
@ -745,6 +778,7 @@ def save_ply(
verts: torch.Tensor,
faces: Optional[torch.LongTensor] = None,
verts_normals: Optional[torch.Tensor] = None,
ascii: bool = False,
decimal_places: Optional[int] = None,
) -> None:
"""
@ -755,7 +789,8 @@ def save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving.
ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
"""
verts_normals = (
@ -781,5 +816,5 @@ def save_ply(
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
with _open_file(f, "w") as f:
_save_ply(f, verts, faces, verts_normals, decimal_places)
with _open_file(f, "wb") as f:
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)

View File

@ -3,6 +3,7 @@
import struct
import unittest
from io import BytesIO, StringIO
from tempfile import TemporaryFile
import torch
from common_testing import TestCaseMixin
@ -144,7 +145,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]])
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
@ -154,7 +155,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
@ -165,14 +166,14 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
def _test_save_load(self, verts, faces):
f = StringIO()
f = BytesIO()
save_ply(f, verts, faces)
f.seek(0)
# raise Exception(f.getvalue())
@ -193,7 +194,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
)
file = StringIO()
file = BytesIO()
save_ply(file, verts=verts, faces=faces, verts_normals=normals)
file.close()
@ -237,15 +238,31 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
def test_simple_save(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
file = StringIO()
save_ply(file, verts=verts, faces=faces)
file.seek(0)
verts2, faces2 = load_ply(file)
self.assertClose(verts, verts2)
self.assertClose(faces, faces2)
for filetype in BytesIO, TemporaryFile:
lengths = {}
for ascii in [True, False]:
file = filetype()
save_ply(file, verts=verts, faces=faces, ascii=ascii)
lengths[ascii] = file.tell()
file.seek(0)
verts2, faces2 = load_ply(file)
self.assertClose(verts, verts2)
self.assertClose(faces, faces2)
file.seek(0)
if ascii:
file.read().decode("ascii")
else:
with self.assertRaises(UnicodeDecodeError):
file.read().decode("ascii")
if filetype is TemporaryFile:
file.close()
self.assertLess(lengths[False], lengths[True], "ascii should be longer")
def test_load_simple_binary(self):
for big_endian in [True, False]:
@ -488,15 +505,21 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
@staticmethod
def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places)
return lambda: save_ply(
BytesIO(),
verts=verts,
faces=faces,
ascii=True,
decimal_places=decimal_places,
)
@staticmethod
def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
f = StringIO()
save_ply(f, verts, faces, decimal_places)
f = BytesIO()
save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places)
s = f.getvalue()
# Recreate stream so it's unaffected by how it was created.
return lambda: load_ply(StringIO(s))
return lambda: load_ply(BytesIO(s))
@staticmethod
def bm_save_simple_ply_with_init(V: int, F: int):