Fix saving / loading empty PLY meshes

Summary:
Similar to D20392526, PLY files without vertices or faces should be allowed:
- a PLY with only vertices can represent a point cloud
- a PLY without any vertex or face is just empty
- a PLY with faces referencing inexistent vertices has invalid data

Reviewed By: gkioxari

Differential Revision: D20400330

fbshipit-source-id: 35a5f072603fd221f382c7faad5f37c3e0b49bb1
This commit is contained in:
Patrick Labatut
2020-04-01 04:49:27 -07:00
committed by Facebook GitHub Bot
parent b64fe51360
commit 83feed56a0
2 changed files with 154 additions and 15 deletions

View File

@@ -139,6 +139,90 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
]
self.assertClose(faces, torch.LongTensor(faces_expected))
def test_save_ply_invalid_shapes(self):
# Invalid vertices shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]])
save_ply(StringIO(), verts, faces)
expected_message = (
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
self.assertTrue(expected_message, error.exception)
# Invalid faces shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_ply(StringIO(), verts, faces)
expected_message = (
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
self.assertTrue(expected_message, error.exception)
def test_save_ply_invalid_indices(self):
message_regex = "Faces have invalid indices"
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_ply(StringIO(), verts, faces)
def _test_save_load(self, verts, faces):
f = StringIO()
save_ply(f, verts, faces)
f.seek(0)
# raise Exception(f.getvalue())
expected_verts, expected_faces = verts, faces
if not len(expected_verts): # Always compare with a (V, 3) tensor
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
if not len(expected_faces): # Always compare with an (F, 3) tensor
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
actual_verts, actual_faces = load_ply(f)
self.assertClose(expected_verts, actual_verts)
self.assertClose(expected_faces, actual_faces)
def test_empty_save_load(self):
# Vertices + empty faces
verts = torch.tensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([])
self._test_save_load(verts, faces)
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
self._test_save_load(verts, faces)
# Faces + empty vertices
message_regex = "Faces have invalid indices"
verts = torch.FloatTensor([])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts, faces)
verts = torch.zeros(size=(0, 3), dtype=torch.float32)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts, faces)
# Empty vertices + empty faces
message_regex = "Empty 'verts' and 'faces' arguments provided"
verts0 = torch.FloatTensor([])
faces0 = torch.LongTensor([])
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts0, faces0)
faces3 = torch.zeros(size=(0, 3), dtype=torch.int64)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts0, faces3)
verts3 = torch.zeros(size=(0, 3), dtype=torch.float32)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts3, faces0)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts3, faces3)
def test_simple_save(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
@@ -378,6 +462,11 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
lines2.insert(5, "property float z")
lines2.insert(5, "property float y")
lines2[-2] = "0 0 0"
lines2[-1] = ""
with self.assertRaisesRegex(ValueError, "Not enough data for face."):
load_ply(StringIO("\n".join(lines2)))
lines2[-1] = "2 0 0"
with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."):
load_ply(StringIO("\n".join(lines2)))