mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-23 06:12:48 +08:00
save_ply binary
Summary: Make save_ply save to binary instead of ascii. An option makes the previous functionality available. save_ply's API accepts a stream, but this is undocumented; that stream must now be a binary stream not a text stream. Avoiding warnings about making tensors from immutable numpy arrays. Possible performance improvement when reading binary files. Fix reading zero-length binary lists. Reviewed By: nikhilaravi Differential Revision: D22333118 fbshipit-source-id: b423dfd3da46e047bead200255f47a7707306811
This commit is contained in:
parent
ebe2693b11
commit
197f1d6217
@ -8,6 +8,7 @@ import struct
|
||||
import sys
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from io import BytesIO
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -386,11 +387,18 @@ def _read_ply_fixed_size_element_binary(
|
||||
np_type = ply_type.np_type
|
||||
type_size = ply_type.size
|
||||
needed_length = definition.count * len(definition.properties)
|
||||
needed_bytes = needed_length * type_size
|
||||
bytes_data = f.read(needed_bytes)
|
||||
if len(bytes_data) != needed_bytes:
|
||||
raise ValueError("Not enough data for %s." % definition.name)
|
||||
data = np.frombuffer(bytes_data, dtype=np_type)
|
||||
if isinstance(f, BytesIO):
|
||||
# np.fromfile is faster but won't work on a BytesIO
|
||||
needed_bytes = needed_length * type_size
|
||||
bytes_data = bytearray(needed_bytes)
|
||||
n_bytes_read = f.readinto(bytes_data)
|
||||
if n_bytes_read != needed_bytes:
|
||||
raise ValueError("Not enough data for %s." % definition.name)
|
||||
data = np.frombuffer(bytes_data, dtype=np_type)
|
||||
else:
|
||||
data = np.fromfile(f, dtype=np_type, count=needed_length)
|
||||
if data.shape[0] != needed_length:
|
||||
raise ValueError("Not enough data for %s." % definition.name)
|
||||
|
||||
if (sys.byteorder == "big") != big_endian:
|
||||
data = data.byteswap()
|
||||
@ -447,6 +455,8 @@ def _try_read_ply_constant_list_binary(
|
||||
If every element has the same size, 2D numpy array corresponding to the
|
||||
data. The rows are the different values. Otherwise None.
|
||||
"""
|
||||
if definition.count == 0:
|
||||
return []
|
||||
property = definition.properties[0]
|
||||
endian_str = ">" if big_endian else "<"
|
||||
length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char
|
||||
@ -689,6 +699,7 @@ def _save_ply(
|
||||
verts: torch.Tensor,
|
||||
faces: torch.LongTensor,
|
||||
verts_normals: torch.Tensor,
|
||||
ascii: bool,
|
||||
decimal_places: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -699,7 +710,8 @@ def _save_ply(
|
||||
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
||||
faces: LongTensor of shsape (F, 3) giving faces.
|
||||
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
ascii: (bool) whether to use the ascii ply format.
|
||||
decimal_places: Number of decimal places for saving if ascii=True.
|
||||
"""
|
||||
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||
@ -707,37 +719,58 @@ def _save_ply(
|
||||
verts_normals.dim() == 2 and verts_normals.size(1) == 3
|
||||
)
|
||||
|
||||
print("ply\nformat ascii 1.0", file=f)
|
||||
print(f"element vertex {verts.shape[0]}", file=f)
|
||||
print("property float x", file=f)
|
||||
print("property float y", file=f)
|
||||
print("property float z", file=f)
|
||||
if ascii:
|
||||
f.write(b"ply\nformat ascii 1.0\n")
|
||||
elif sys.byteorder == "big":
|
||||
f.write(b"ply\nformat binary_big_endian 1.0\n")
|
||||
else:
|
||||
f.write(b"ply\nformat binary_little_endian 1.0\n")
|
||||
f.write(f"element vertex {verts.shape[0]}\n".encode("ascii"))
|
||||
f.write(b"property float x\n")
|
||||
f.write(b"property float y\n")
|
||||
f.write(b"property float z\n")
|
||||
if verts_normals.numel() > 0:
|
||||
print("property float nx", file=f)
|
||||
print("property float ny", file=f)
|
||||
print("property float nz", file=f)
|
||||
print(f"element face {faces.shape[0]}", file=f)
|
||||
print("property list uchar int vertex_index", file=f)
|
||||
print("end_header", file=f)
|
||||
f.write(b"property float nx\n")
|
||||
f.write(b"property float ny\n")
|
||||
f.write(b"property float nz\n")
|
||||
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
|
||||
f.write(b"property list uchar int vertex_index\n")
|
||||
f.write(b"end_header\n")
|
||||
|
||||
if not (len(verts) or len(faces)):
|
||||
warnings.warn("Empty 'verts' and 'faces' arguments provided")
|
||||
return
|
||||
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
|
||||
if ascii:
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
np.savetxt(f, vert_data, float_str)
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
|
||||
vert_data = torch.cat((verts, verts_normals), dim=1)
|
||||
np.savetxt(f, vert_data.detach().numpy(), float_str)
|
||||
assert vert_data.dtype == np.float32
|
||||
if isinstance(f, BytesIO):
|
||||
# tofile only works with real files, but is faster than this.
|
||||
f.write(vert_data.tobytes())
|
||||
else:
|
||||
vert_data.tofile(f)
|
||||
|
||||
faces_array = faces.detach().numpy()
|
||||
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
if len(faces_array):
|
||||
np.savetxt(f, faces_array, "3 %d %d %d")
|
||||
if ascii:
|
||||
np.savetxt(f, faces_array, "3 %d %d %d")
|
||||
else:
|
||||
# rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
|
||||
faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
|
||||
faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
|
||||
if isinstance(f, BytesIO):
|
||||
f.write(faces_uints.tobytes())
|
||||
else:
|
||||
faces_uints.tofile(f)
|
||||
|
||||
|
||||
def save_ply(
|
||||
@ -745,6 +778,7 @@ def save_ply(
|
||||
verts: torch.Tensor,
|
||||
faces: Optional[torch.LongTensor] = None,
|
||||
verts_normals: Optional[torch.Tensor] = None,
|
||||
ascii: bool = False,
|
||||
decimal_places: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -755,7 +789,8 @@ def save_ply(
|
||||
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
ascii: (bool) whether to use the ascii ply format.
|
||||
decimal_places: Number of decimal places for saving if ascii=True.
|
||||
"""
|
||||
|
||||
verts_normals = (
|
||||
@ -781,5 +816,5 @@ def save_ply(
|
||||
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
with _open_file(f, "w") as f:
|
||||
_save_ply(f, verts, faces, verts_normals, decimal_places)
|
||||
with _open_file(f, "wb") as f:
|
||||
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)
|
||||
|
@ -3,6 +3,7 @@
|
||||
import struct
|
||||
import unittest
|
||||
from io import BytesIO, StringIO
|
||||
from tempfile import TemporaryFile
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
@ -144,7 +145,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
expected_message = (
|
||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
)
|
||||
@ -154,7 +155,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
expected_message = (
|
||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
@ -165,14 +166,14 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
|
||||
faces = torch.LongTensor([[-1, 0, 1]])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
save_ply(StringIO(), verts, faces)
|
||||
save_ply(BytesIO(), verts, faces)
|
||||
|
||||
def _test_save_load(self, verts, faces):
|
||||
f = StringIO()
|
||||
f = BytesIO()
|
||||
save_ply(f, verts, faces)
|
||||
f.seek(0)
|
||||
# raise Exception(f.getvalue())
|
||||
@ -193,7 +194,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
normals = torch.tensor(
|
||||
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
file = StringIO()
|
||||
file = BytesIO()
|
||||
save_ply(file, verts=verts, faces=faces, verts_normals=normals)
|
||||
file.close()
|
||||
|
||||
@ -237,15 +238,31 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_simple_save(self):
|
||||
verts = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]], dtype=torch.float32
|
||||
)
|
||||
faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
|
||||
file = StringIO()
|
||||
save_ply(file, verts=verts, faces=faces)
|
||||
file.seek(0)
|
||||
verts2, faces2 = load_ply(file)
|
||||
self.assertClose(verts, verts2)
|
||||
self.assertClose(faces, faces2)
|
||||
for filetype in BytesIO, TemporaryFile:
|
||||
lengths = {}
|
||||
for ascii in [True, False]:
|
||||
file = filetype()
|
||||
save_ply(file, verts=verts, faces=faces, ascii=ascii)
|
||||
lengths[ascii] = file.tell()
|
||||
|
||||
file.seek(0)
|
||||
verts2, faces2 = load_ply(file)
|
||||
self.assertClose(verts, verts2)
|
||||
self.assertClose(faces, faces2)
|
||||
|
||||
file.seek(0)
|
||||
if ascii:
|
||||
file.read().decode("ascii")
|
||||
else:
|
||||
with self.assertRaises(UnicodeDecodeError):
|
||||
file.read().decode("ascii")
|
||||
|
||||
if filetype is TemporaryFile:
|
||||
file.close()
|
||||
self.assertLess(lengths[False], lengths[True], "ascii should be longer")
|
||||
|
||||
def test_load_simple_binary(self):
|
||||
for big_endian in [True, False]:
|
||||
@ -488,15 +505,21 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def _bm_save_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
|
||||
return lambda: save_ply(StringIO(), verts, faces, decimal_places=decimal_places)
|
||||
return lambda: save_ply(
|
||||
BytesIO(),
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
ascii=True,
|
||||
decimal_places=decimal_places,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _bm_load_ply(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
|
||||
f = StringIO()
|
||||
save_ply(f, verts, faces, decimal_places)
|
||||
f = BytesIO()
|
||||
save_ply(f, verts=verts, faces=faces, ascii=True, decimal_places=decimal_places)
|
||||
s = f.getvalue()
|
||||
# Recreate stream so it's unaffected by how it was created.
|
||||
return lambda: load_ply(StringIO(s))
|
||||
return lambda: load_ply(BytesIO(s))
|
||||
|
||||
@staticmethod
|
||||
def bm_save_simple_ply_with_init(V: int, F: int):
|
||||
|
Loading…
x
Reference in New Issue
Block a user