pytorch3d/tests/test_ply_io.py
facebook-github-bot dbf06b504b Initial commit
fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
2020-01-23 11:53:46 -08:00

435 lines
15 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import struct
import unittest
from io import BytesIO, StringIO
import torch
from pytorch3d.io.ply_io import _load_ply_raw, load_ply, save_ply
from common_testing import TestCaseMixin
class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
    """Tests for ``_load_ply_raw``, ``load_ply`` and ``save_ply``.

    The tests feed hand-written PLY files (ascii and binary) to the loaders
    and check the parsed contents, check the error raised for a range of
    malformed files, and round-trip a small mesh through ``save_ply`` and
    ``load_ply``.

    ``assertClose`` is not part of ``unittest.TestCase``; it is presumably
    supplied by ``TestCaseMixin`` (verify in ``common_testing``).
    """

    @staticmethod
    def _line_ending_streams(ply_text):
        """Yield ``(line_ending, stream)`` pairs for one ascii PLY document.

        Three variants are produced, in this order:

        * ``None``   -- a text ``StringIO`` over ``ply_text`` unchanged;
        * ``"\\n"``   -- a ``BytesIO`` over the ascii-encoded text;
        * ``"\\r\\n"`` -- a ``BytesIO`` where every ``\\n`` byte is replaced
          by ``\\r\\n``.

        Args:
            ply_text: The PLY document as a ``str`` using ``\\n`` line ends.
        """
        yield None, StringIO(ply_text)
        encoded = ply_text.encode("ascii")
        yield "\n", BytesIO(encoded)
        yield "\r\n", BytesIO(encoded.replace(b"\n", b"\r\n"))

    def _assert_irregular_lists(self, irregular):
        """Check the ``irregular_list`` element shared by the raw-load tests.

        The element holds three rows whose list property has lengths
        4, 4 and 3, so it is expected to come back as a plain ``list`` with
        one single-item entry per row.
        """
        self.assertEqual(len(irregular), 3)
        self.assertIsInstance(irregular, list)
        expected_rows = ([0, 1, 2, 3], [7, 6, 5, 4], [4, 5, 1])
        for entry, expected in zip(irregular, expected_rows):
            # Each entry must contain exactly one property value.
            [values] = entry
            self.assertClose(values, expected)

    def test_raw_load_simple_ascii(self):
        """Raw-load an ascii cube with an extra variable-length list element."""
        ply_file = "\n".join(
            [
                "ply",
                "format ascii 1.0",
                "comment made by Greg Turk",
                "comment this file is a cube",
                "element vertex 8",
                "property float x",
                "property float y",
                "property float z",
                "element face 6",
                "property list uchar int vertex_index",
                "element irregular_list 3",
                "property list uchar int vertex_index",
                "end_header",
                "0 0 0",
                "0 0 1",
                "0 1 1",
                "0 1 0",
                "1 0 0",
                "1 0 1",
                "1 1 1",
                "1 1 0",
                "4 0 1 2 3",
                "4 7 6 5 4",
                "4 0 4 5 1",
                "4 1 5 6 2",
                "4 2 6 7 3",
                "4 3 7 4 0",  # end of faces
                "4 0 1 2 3",
                "4 7 6 5 4",
                "3 4 5 1",
            ]
        )
        for line_ending, stream in self._line_ending_streams(ply_file):
            # subTest makes a failure report which line-ending variant broke.
            with self.subTest(line_ending=line_ending):
                header, data = _load_ply_raw(stream)
                self.assertTrue(header.ascii)
                # One entry per element: vertex, face, irregular_list.
                self.assertEqual(len(data), 3)
                self.assertTupleEqual(data["face"].shape, (6, 4))
                self.assertClose([0, 1, 2, 3], data["face"][0])
                self.assertClose([3, 7, 4, 0], data["face"][5])
                self.assertTupleEqual(data["vertex"].shape, (8, 3))
                self._assert_irregular_lists(data["irregular_list"])

    def test_load_simple_ascii(self):
        """Load an ascii cube as a mesh and check verts and faces.

        The six quads are expected to come back as twelve triangles: per the
        expected values below, quad (v0, v1, v2, v3) contributes
        (v0, v1, v2) and (v0, v2, v3), with all first triangles listed
        before all second triangles.
        """
        ply_file = "\n".join(
            [
                "ply",
                "format ascii 1.0",
                "comment made by Greg Turk",
                "comment this file is a cube",
                "element vertex 8",
                "property float x",
                "property float y",
                "property float z",
                "element face 6",
                "property list uchar int vertex_index",
                "end_header",
                "0 0 0",
                "0 0 1",
                "0 1 1",
                "0 1 0",
                "1 0 0",
                "1 0 1",
                "1 1 1",
                "1 1 0",
                "4 0 1 2 3",
                "4 7 6 5 4",
                "4 0 4 5 1",
                "4 1 5 6 2",
                "4 2 6 7 3",
                "4 3 7 4 0",
            ]
        )
        verts_expected = [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 1],
            [0, 1, 0],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 1],
            [1, 1, 0],
        ]
        faces_expected = [
            [0, 1, 2],
            [7, 6, 5],
            [0, 4, 5],
            [1, 5, 6],
            [2, 6, 7],
            [3, 7, 4],
            [0, 2, 3],
            [7, 5, 4],
            [0, 5, 1],
            [1, 6, 2],
            [2, 7, 3],
            [3, 4, 0],
        ]
        for line_ending, stream in self._line_ending_streams(ply_file):
            with self.subTest(line_ending=line_ending):
                verts, faces = load_ply(stream)
                self.assertEqual(verts.shape, (8, 3))
                self.assertEqual(faces.shape, (12, 3))
                self.assertClose(verts, torch.FloatTensor(verts_expected))
                self.assertClose(faces, torch.LongTensor(faces_expected))

    def test_simple_save(self):
        """Round-trip a small mesh through ``save_ply`` and ``load_ply``."""
        # Five vertices, so that every face index used below (0..4) refers
        # to a vertex that actually exists.
        verts = torch.tensor(
            [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 0]],
            dtype=torch.float32,
        )
        faces = torch.tensor([[0, 1, 2], [0, 3, 4]])
        file = StringIO()
        save_ply(file, verts=verts, faces=faces)
        # Rewind so the loader reads what was just written.
        file.seek(0)
        verts2, faces2 = load_ply(file)
        self.assertClose(verts, verts2)
        self.assertClose(faces, faces2)

    def test_load_simple_binary(self):
        """Raw-load a hand-packed binary PLY in both byte orders."""
        verts = (
            "0 0 0 "
            "0 0 1 "
            "0 1 1 "
            "0 1 0 "
            "1 0 0 "
            "1 0 1 "
            "1 1 1 "
            "1 1 0"
        ).split()
        faces = (
            "4 0 1 2 3 "
            "4 7 6 5 4 "
            "4 0 4 5 1 "
            "4 1 5 6 2 "
            "4 2 6 7 3 "
            "4 3 7 4 0 "  # end of first 6
            "4 0 1 2 3 "
            "4 7 6 5 4 "
            "3 4 5 1"
        ).split()
        vertex_values = list(map(float, verts))

        for big_endian in [True, False]:
            with self.subTest(big_endian=big_endian):
                # The 16-bit value 1 in the byte order under test.
                short_one = b"\00\01" if big_endian else b"\01\00"
                # Two "mixed" rows (list count: short, items: uint, then a
                # short):
                #   row 0: count 0, bar = 0x0303 (same in either byte order)
                #   row 1: count 1, one 4-byte item 00 01 01 01, bar = 00 02
                mixed_data = b"\00\00" b"\03\03" + (
                    short_one + b"\00\01\01\01" b"\00\02"
                )
                # char + uchar + short + ushort + int + uint = 14 bytes, all
                # bits set: -1 for signed types, the maximum for unsigned.
                minus_one_data = b"\xff" * 14
                endian_char = ">" if big_endian else "<"
                format_line = (
                    "format binary_big_endian 1.0"
                    if big_endian
                    else "format binary_little_endian 1.0"
                )
                # 8 vertices x 3 floats.
                vertex_data = struct.pack(endian_char + "24f", *vertex_values)
                # Same coordinates, but y is stored as a double.
                vertex1_data = struct.pack(
                    endian_char + "fdf" * 8, *vertex_values
                )
                # 44 single-byte values: the counts and indices of all 9
                # list rows (6 faces + 3 irregular rows).
                face_char_data = struct.pack(
                    endian_char + "44b", *map(int, faces)
                )
                header = "\n".join(
                    [
                        "ply",
                        format_line,
                        "element vertex 8",
                        "property float x",
                        "property float y",
                        "property float z",
                        "element vertex1 8",
                        "property float x",
                        "property double y",
                        "property float z",
                        "element face 6",
                        "property list uchar uchar vertex_index",
                        "element irregular_list 3",
                        "property list uchar uchar vertex_index",
                        "element mixed 2",
                        "property list short uint foo",
                        "property short bar",
                        "element minus_ones 1",
                        "property char 1",
                        "property uchar 2",
                        "property short 3",
                        "property ushort 4",
                        "property int 5",
                        "property uint 6",
                        "end_header\n",
                    ]
                )
                ply_file = b"".join(
                    [
                        header.encode("ascii"),
                        vertex_data,
                        vertex1_data,
                        face_char_data,
                        mixed_data,
                        minus_one_data,
                    ]
                )

                metadata, data = _load_ply_raw(BytesIO(ply_file))
                self.assertFalse(metadata.ascii)
                # One entry per element declared in the header.
                self.assertEqual(len(data), 6)
                self.assertTupleEqual(data["face"].shape, (6, 4))
                self.assertClose([0, 1, 2, 3], data["face"][0])
                self.assertClose([3, 7, 4, 0], data["face"][5])

                self.assertTupleEqual(data["vertex"].shape, (8, 3))
                self.assertEqual(len(data["vertex1"]), 8)
                self.assertClose(data["vertex"], data["vertex1"])
                self.assertClose(data["vertex"].flatten(), vertex_values)

                self._assert_irregular_lists(data["irregular_list"])

                mixed = data["mixed"]
                self.assertEqual(len(mixed), 2)
                self.assertEqual(len(mixed[0]), 2)
                self.assertEqual(len(mixed[1]), 2)
                self.assertEqual(mixed[0][1], 3 * 256 + 3)
                self.assertEqual(len(mixed[0][0]), 0)
                self.assertEqual(mixed[1][1], (2 if big_endian else 2 * 256))
                # Bytes 00 01 01 01 read big-endian; little-endian reads the
                # same bytes as this value shifted up by one byte.
                base = 1 + 256 + 256 * 256
                self.assertEqual(len(mixed[1][0]), 1)
                self.assertEqual(
                    mixed[1][0][0], base if big_endian else 256 * base
                )

                self.assertListEqual(
                    data["minus_ones"], [(-1, 255, -1, 65535, -1, 4294967295)]
                )

    def test_bad_ply_syntax(self):
        """Some syntactically bad ply files."""
        lines = [
            "ply",
            "format ascii 1.0",
            "comment dashfadskfj;k",
            "element vertex 1",
            "property float x",
            "element listy 1",
            "property list uint int x",
            "end_header",
            "0",
            "0",
        ]

        def as_stream(ply_lines):
            # Join the lines into one text stream, as a file would appear.
            return StringIO("\n".join(ply_lines))

        def assert_raw_load_fails(ply_lines, message):
            # _load_ply_raw must raise ValueError matching `message`.
            with self.assertRaisesRegex(ValueError, message):
                _load_ply_raw(as_stream(ply_lines))

        def assert_mesh_load_fails(ply_lines, message):
            # load_ply must raise ValueError matching `message`.
            with self.assertRaisesRegex(ValueError, message):
                load_ply(as_stream(ply_lines))

        # this is ok
        _load_ply_raw(as_stream(lines.copy()))

        lines2 = lines.copy()
        lines2[0] = "PLY"
        assert_raw_load_fails(lines2, "Invalid file header.")

        lines2 = lines.copy()
        lines2[2] = "#this is a comment"
        assert_raw_load_fails(lines2, "Invalid line.*")

        # Swap the first element line with its property line.
        lines2 = lines.copy()
        lines2[3] = lines[4]
        lines2[4] = lines[3]
        assert_raw_load_fails(
            lines2, "Encountered property before any element."
        )

        lines2 = lines.copy()
        lines2[8] = "1 2"
        assert_raw_load_fails(lines2, "Inconsistent data for vertex.")

        # Drop the only row of listy data.
        lines2 = lines[:-1]
        assert_raw_load_fails(lines2, "Not enough data for listy.")

        # Declare two listy rows while supplying only one.
        lines2 = lines.copy()
        lines2[5] = "element listy 2"
        assert_raw_load_fails(lines2, "Not enough data for listy.")

        lines2 = lines.copy()
        lines2.insert(4, "property short x")
        assert_raw_load_fails(
            lines2, "Cannot have two properties called x in vertex."
        )

        lines2 = lines.copy()
        lines2.insert(4, "property zz short")
        assert_raw_load_fails(lines2, "Invalid datatype: zz")

        lines2 = lines.copy()
        lines2.append("3")
        assert_raw_load_fails(lines2, "Extra data at end of file.")

        lines2 = lines.copy()
        lines2.append("comment foo")
        assert_raw_load_fails(lines2, "Extra data at end of file.")

        lines2 = lines.copy()
        lines2.insert(4, "element bad 1")
        assert_raw_load_fails(lines2, "Found an element with no properties.")

        # A list row whose count matches its number of items is accepted.
        lines2 = lines.copy()
        lines2[-1] = "3 2 3 3"
        _load_ply_raw(as_stream(lines2))

        lines2 = lines.copy()
        lines2[-1] = "3 1 2 3 4"
        msg = "A line of listy data did not have the specified length."
        assert_raw_load_fails(lines2, msg)

        lines2 = lines.copy()
        lines2[3] = "element vertex one"
        msg = "Number of items for vertex was not a number."
        assert_raw_load_fails(lines2, msg)

        # Heterogeneous cases
        lines2 = lines.copy()
        lines2.insert(4, "property double y")
        assert_raw_load_fails(lines2, "Too little data for an element.")

        lines2[-2] = "3.3 4.2"
        _load_ply_raw(as_stream(lines2))

        lines2[-2] = "3.3 4.3 2"
        assert_raw_load_fails(lines2, "Too much data for an element.")

        # Now make the ply file actually be readable as a Mesh
        assert_mesh_load_fails(lines, "The ply file has no face element.")

        lines2 = lines.copy()
        lines2[5] = "element face 1"
        assert_mesh_load_fails(lines2, "Invalid vertices in file.")

        # Inserting z first and then y at the same index leaves the vertex
        # properties in x, y, z order.
        lines2.insert(5, "property float z")
        lines2.insert(5, "property float y")
        lines2[-2] = "0 0 0"
        assert_mesh_load_fails(lines2, "Faces must have at least 3 vertices.")

        # Good one
        lines2[-1] = "3 0 0 0"
        load_ply(as_stream(lines2))

    @staticmethod
    def save_ply_bm(V: int, F: int):
        """Return a zero-argument callable that saves a mesh to a new stream.

        Args:
            V: Number of (identical) vertices in the mesh.
            F: Number of (identical) faces in the mesh.

        Returns:
            A function which writes the mesh to a fresh ``StringIO`` with
            ``save_ply``; intended to be timed by a benchmark harness.
        """
        verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)
        faces_list = torch.tensor(F * [[0, 1, 2]]).view(-1, 3)

        def save_mesh():
            file = StringIO()
            # NOTE(review): the fourth positional argument is presumably the
            # number of decimal places -- confirm against save_ply.
            save_ply(file, verts_list, faces_list, 2)

        return save_mesh

    @staticmethod
    def load_ply_bm(V: int, F: int):
        """Return a zero-argument callable that loads a pre-serialized mesh.

        The mesh is written once with ``save_ply`` up front; only the load
        happens inside the returned function.

        Args:
            V: Number of (identical) vertices in the mesh.
            F: Number of (identical) faces in the mesh.

        Returns:
            A function which parses the serialized mesh with ``load_ply``;
            intended to be timed by a benchmark harness.
        """
        verts = torch.tensor([[0.1, 0.2, 0.3]]).expand(V, 3)
        faces = torch.tensor([[0, 1, 2]], dtype=torch.int64).expand(F, 3)
        ply_file = StringIO()
        save_ply(ply_file, verts=verts, faces=faces)
        ply = ply_file.getvalue()

        def load_mesh():
            # Recreate the stream on every call so each load reads from the
            # start, unaffected by how the text was originally written.
            load_ply(StringIO(ply))

        return load_mesh