save_ply binary

Summary:
Make save_ply save to binary instead of ascii. An option makes the previous functionality available. save_ply's API accepts a stream, but this is undocumented; that stream must now be a binary stream not a text stream.

Avoiding warnings about making tensors from immutable numpy arrays.

Possible performance improvement when reading binary files.

Fix reading zero-length binary lists.

Reviewed By: nikhilaravi

Differential Revision: D22333118

fbshipit-source-id: b423dfd3da46e047bead200255f47a7707306811
This commit is contained in:
Jeremy Reizenstein
2020-09-21 05:14:59 -07:00
committed by Facebook GitHub Bot
parent ebe2693b11
commit 197f1d6217
2 changed files with 102 additions and 44 deletions

View File

@@ -3,6 +3,7 @@
import struct
import unittest
from io import BytesIO, StringIO
from tempfile import TemporaryFile
import torch
from common_testing import TestCaseMixin
@@ -144,7 +145,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]])
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
@@ -154,7 +155,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
@@ -165,14 +166,14 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
save_ply(BytesIO(), verts, faces)
def _test_save_load(self, verts, faces):
f = StringIO()
f = BytesIO()
save_ply(f, verts, faces)
f.seek(0)
# raise Exception(f.getvalue())
@@ -193,7 +194,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
)
file = StringIO()
file = BytesIO()
save_ply(file, verts=verts, faces=faces, verts_normals=normals)
file.close()
@@ -237,15 +238,31 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
def test_simple_save(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
file = StringIO()
save_ply(file, verts=verts, faces=faces)
file.seek(0)
verts2, faces2 = load_ply(file)
self.assertClose(verts, verts2)
self.assertClose(faces, faces2)
for filetype in BytesIO, TemporaryFile:
lengths = {}
for ascii in [True, False]:
file = filetype()
save_ply(file, verts=verts, faces=faces, ascii=ascii)
lengths[ascii] = file.tell()
file.seek(0)
verts2, faces2 = load_ply(file)
self.assertClose(verts, verts2)
self.assertClose(faces, faces2)
file.seek(0)
if ascii:
file.read().decode("ascii")
else:
with self.assertRaises(UnicodeDecodeError):
file.read().decode("ascii")
if filetype is TemporaryFile:
file.close()
self.assertLess(lengths[False], lengths[True], "ascii should be longer")
def test_load_simple_binary(self):
for big_endian in [True, False]:
@@ -488,15 +505,21 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
@staticmethod
def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places)
return lambda: save_ply(
BytesIO(),
verts=verts,
faces=faces,
ascii=True,
decimal_places=decimal_places,
)
@staticmethod
def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
f = StringIO()
save_ply(f, verts, faces, decimal_places)
f = BytesIO()
save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places)
s = f.getvalue()
# Recreate stream so it's unaffected by how it was created.
return lambda: load_ply(StringIO(s))
return lambda: load_ply(BytesIO(s))
@staticmethod
def bm_save_simple_ply_with_init(V: int, F: int):