mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
save_ply binary
Summary: Make save_ply save to binary instead of ascii. An option makes the previous functionality available. save_ply's API accepts a stream, but this is undocumented; that stream must now be a binary stream not a text stream. Avoiding warnings about making tensors from immutable numpy arrays. Possible performance improvement when reading binary files. Fix reading zero-length binary lists. Reviewed By: nikhilaravi Differential Revision: D22333118 fbshipit-source-id: b423dfd3da46e047bead200255f47a7707306811
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ebe2693b11
commit
197f1d6217
@@ -3,6 +3,7 @@
|
||||
import struct
|
||||
import unittest
|
||||
from io import BytesIO, StringIO
|
||||
from tempfile import TemporaryFile
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
@@ -144,7 +145,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
expected_message = (
|
||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
)
|
||||
@@ -154,7 +155,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
expected_message = (
|
||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
@@ -165,14 +166,14 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
|
||||
faces = torch.LongTensor([[-1, 0, 1]])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
|
||||
def _test_save_load(self, verts, faces):
|
||||
f = StringIO()
|
||||
f = BytesIO()
|
||||
save_ply(f, verts, faces)
|
||||
f.seek(0)
|
||||
# raise Exception(f.getvalue())
|
||||
@@ -193,7 +194,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
normals = torch.tensor(
|
||||
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
file = StringIO()
|
||||
file = BytesIO()
|
||||
save_ply(file, verts=verts, faces=faces, verts_normals=normals)
|
||||
file.close()
|
||||
|
||||
@@ -237,15 +238,31 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_simple_save(self):
|
||||
verts = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32
|
||||
)
|
||||
faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
|
||||
file = StringIO()
|
||||
save_ply(file, verts=verts, faces=faces)
|
||||
file.seek(0)
|
||||
verts2, faces2 = load_ply(file)
|
||||
self.assertClose(verts, verts2)
|
||||
self.assertClose(faces, faces2)
|
||||
for filetype in BytesIO, TemporaryFile:
|
||||
lengths = {}
|
||||
for ascii in [True, False]:
|
||||
file = filetype()
|
||||
save_ply(file, verts=verts, faces=faces, ascii=ascii)
|
||||
lengths[ascii] = file.tell()
|
||||
|
||||
file.seek(0)
|
||||
verts2, faces2 = load_ply(file)
|
||||
self.assertClose(verts, verts2)
|
||||
self.assertClose(faces, faces2)
|
||||
|
||||
file.seek(0)
|
||||
if ascii:
|
||||
file.read().decode("ascii")
|
||||
else:
|
||||
with self.assertRaises(UnicodeDecodeError):
|
||||
file.read().decode("ascii")
|
||||
|
||||
if filetype is TemporaryFile:
|
||||
file.close()
|
||||
self.assertLess(lengths[False], lengths[True], "ascii should be longer")
|
||||
|
||||
def test_load_simple_binary(self):
|
||||
for big_endian in [True, False]:
|
||||
@@ -488,15 +505,21 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
|
||||
return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places)
|
||||
return lambda: save_ply(
|
||||
BytesIO(),
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
ascii=True,
|
||||
decimal_places=decimal_places,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
|
||||
f = StringIO()
|
||||
save_ply(f, verts, faces, decimal_places)
|
||||
f = BytesIO()
|
||||
save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places)
|
||||
s = f.getvalue()
|
||||
# Recreate stream so it's unaffected by how it was created.
|
||||
return lambda: load_ply(StringIO(s))
|
||||
return lambda: load_ply(BytesIO(s))
|
||||
|
||||
@staticmethod
|
||||
def bm_save_simple_ply_with_init(V: int, F: int):
|
||||
|
||||
Reference in New Issue
Block a user