mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Fix saving / loading empty PLY meshes
Summary: Similar to D20392526, PLY files without vertices or faces should be allowed: - a PLY with only vertices can represent a point cloud - a PLY without any vertex or face is just empty - a PLY with faces referencing inexistent vertices has invalid data Reviewed By: gkioxari Differential Revision: D20400330 fbshipit-source-id: 35a5f072603fd221f382c7faad5f37c3e0b49bb1
This commit is contained in:
parent
b64fe51360
commit
83feed56a0
@ -213,6 +213,17 @@ class _PlyHeader:
|
||||
self.elements.append(_PlyElementType(items[1], count))
|
||||
|
||||
|
||||
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Return a 2D tensor with the specified cols and dtype filled with data,
|
||||
even when data is empty.
|
||||
"""
|
||||
if not len(data):
|
||||
return torch.zeros((0, cols), dtype=dtype)
|
||||
|
||||
return torch.tensor(data, dtype=dtype)
|
||||
|
||||
|
||||
def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType):
|
||||
"""
|
||||
Given an element which has no lists and one type, read the
|
||||
@ -227,10 +238,17 @@ def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType):
|
||||
values. There is one column for each property.
|
||||
"""
|
||||
np_type = _PLY_TYPES[definition.properties[0].data_type].np_type
|
||||
data = np.loadtxt(
|
||||
f, dtype=np_type, comments=None, ndmin=2, max_rows=definition.count
|
||||
)
|
||||
if data.shape[1] != len(definition.properties):
|
||||
old_offset = f.tell()
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", message=".* Empty input file.*", category=UserWarning
|
||||
)
|
||||
data = np.loadtxt(
|
||||
f, dtype=np_type, comments=None, ndmin=2, max_rows=definition.count
|
||||
)
|
||||
if not len(data): # np.loadtxt() seeks even on empty data
|
||||
f.seek(old_offset)
|
||||
if definition.count and data.shape[1] != len(definition.properties):
|
||||
raise ValueError("Inconsistent data for %s." % definition.name)
|
||||
if data.shape[0] != definition.count:
|
||||
raise ValueError("Not enough data for %s." % definition.name)
|
||||
@ -252,7 +270,7 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType):
|
||||
data. The rows are the different values. Otherwise None.
|
||||
"""
|
||||
np_type = _PLY_TYPES[definition.properties[0].data_type].np_type
|
||||
start_point = f.tell()
|
||||
old_offset = f.tell()
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
@ -262,8 +280,10 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType):
|
||||
f, dtype=np_type, comments=None, ndmin=2, max_rows=definition.count
|
||||
)
|
||||
except ValueError:
|
||||
f.seek(start_point)
|
||||
f.seek(old_offset)
|
||||
return None
|
||||
if not len(data): # np.loadtxt() seeks even on empty data
|
||||
f.seek(old_offset)
|
||||
if (data.shape[1] - 1 != data[:, 0]).any():
|
||||
msg = "A line of %s data did not have the specified length."
|
||||
raise ValueError(msg % definition.name)
|
||||
@ -442,7 +462,7 @@ def _try_read_ply_constant_list_binary(
|
||||
[length] = length_struct.unpack(bytes_data)
|
||||
return length
|
||||
|
||||
start_point = f.tell()
|
||||
old_offset = f.tell()
|
||||
|
||||
length = get_length()
|
||||
np_type = _PLY_TYPES[definition.properties[0].data_type].np_type
|
||||
@ -459,7 +479,7 @@ def _try_read_ply_constant_list_binary(
|
||||
if i + 1 == definition.count:
|
||||
break
|
||||
if length != get_length():
|
||||
f.seek(start_point)
|
||||
f.seek(old_offset)
|
||||
return None
|
||||
if (sys.byteorder == "big") != big_endian:
|
||||
output = output.byteswap()
|
||||
@ -644,20 +664,25 @@ def load_ply(f):
|
||||
if face is None:
|
||||
raise ValueError("The ply file has no face element.")
|
||||
|
||||
if not isinstance(vertex, np.ndarray) or vertex.ndim != 2 or vertex.shape[1] != 3:
|
||||
if len(vertex) and (
|
||||
not isinstance(vertex, np.ndarray) or vertex.ndim != 2 or vertex.shape[1] != 3
|
||||
):
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
verts = torch.tensor(vertex, dtype=torch.float32)
|
||||
verts = _make_tensor(vertex, cols=3, dtype=torch.float32)
|
||||
|
||||
face_head = next(head for head in header.elements if head.name == "face")
|
||||
if len(face_head.properties) != 1 or face_head.properties[0].list_size_type is None:
|
||||
raise ValueError("Unexpected form of faces data.")
|
||||
# face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
|
||||
# but we don't need to enforce this.
|
||||
if isinstance(face, np.ndarray) and face.ndim == 2:
|
||||
|
||||
if not len(face):
|
||||
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
|
||||
if face.shape[1] < 3:
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)]
|
||||
faces = torch.tensor(np.vstack(face_arrays), dtype=torch.int64)
|
||||
faces = torch.LongTensor(np.vstack(face_arrays))
|
||||
else:
|
||||
face_list = []
|
||||
for face_item in face:
|
||||
@ -667,7 +692,10 @@ def load_ply(f):
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
for i in range(face_item.shape[0] - 2):
|
||||
face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
|
||||
faces = torch.tensor(face_list, dtype=torch.int64)
|
||||
faces = _make_tensor(face_list, cols=3, dtype=torch.int64)
|
||||
|
||||
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
|
||||
return verts, faces
|
||||
|
||||
@ -682,6 +710,16 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]):
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
"""
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
raise ValueError(
|
||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
)
|
||||
|
||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||
raise ValueError(
|
||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
|
||||
print("ply\nformat ascii 1.0", file=f)
|
||||
print(f"element vertex {verts.shape[0]}", file=f)
|
||||
print("property float x", file=f)
|
||||
@ -691,13 +729,25 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]):
|
||||
print("property list uchar int vertex_index", file=f)
|
||||
print("end_header", file=f)
|
||||
|
||||
if not (len(verts) or len(faces)):
|
||||
warnings.warn("Empty 'verts' and 'faces' arguments provided")
|
||||
return
|
||||
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
|
||||
np.savetxt(f, verts.detach().numpy(), float_str)
|
||||
np.savetxt(f, faces.detach().numpy(), "3 %d %d %d")
|
||||
verts_array = verts.detach().numpy()
|
||||
np.savetxt(f, verts_array, float_str)
|
||||
|
||||
faces_array = faces.detach().numpy()
|
||||
|
||||
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
|
||||
if len(faces_array):
|
||||
np.savetxt(f, faces_array, "3 %d %d %d")
|
||||
|
||||
|
||||
def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
|
||||
|
@ -139,6 +139,90 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
]
|
||||
self.assertClose(faces, torch.LongTensor(faces_expected))
|
||||
|
||||
def test_save_ply_invalid_shapes(self):
|
||||
# Invalid vertices shape
|
||||
with self.assertRaises(ValueError) as error:
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
save_ply(StringIO(), verts, faces)
|
||||
expected_message = (
|
||||
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
)
|
||||
self.assertTrue(expected_message, error.exception)
|
||||
|
||||
# Invalid faces shape
|
||||
with self.assertRaises(ValueError) as error:
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
|
||||
save_ply(StringIO(), verts, faces)
|
||||
expected_message = (
|
||||
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
self.assertTrue(expected_message, error.exception)
|
||||
|
||||
def test_save_ply_invalid_indices(self):
|
||||
message_regex = "Faces have invalid indices"
|
||||
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
save_ply(StringIO(), verts, faces)
|
||||
|
||||
faces = torch.LongTensor([[-1, 0, 1]])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
save_ply(StringIO(), verts, faces)
|
||||
|
||||
def _test_save_load(self, verts, faces):
|
||||
f = StringIO()
|
||||
save_ply(f, verts, faces)
|
||||
f.seek(0)
|
||||
# raise Exception(f.getvalue())
|
||||
expected_verts, expected_faces = verts, faces
|
||||
if not len(expected_verts): # Always compare with a (V, 3) tensor
|
||||
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
|
||||
if not len(expected_faces): # Always compare with an (F, 3) tensor
|
||||
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||
actual_verts, actual_faces = load_ply(f)
|
||||
self.assertClose(expected_verts, actual_verts)
|
||||
self.assertClose(expected_faces, actual_faces)
|
||||
|
||||
def test_empty_save_load(self):
|
||||
# Vertices + empty faces
|
||||
verts = torch.tensor([[0.1, 0.2, 0.3]])
|
||||
faces = torch.LongTensor([])
|
||||
self._test_save_load(verts, faces)
|
||||
|
||||
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||
self._test_save_load(verts, faces)
|
||||
|
||||
# Faces + empty vertices
|
||||
message_regex = "Faces have invalid indices"
|
||||
verts = torch.FloatTensor([])
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
self._test_save_load(verts, faces)
|
||||
|
||||
verts = torch.zeros(size=(0, 3), dtype=torch.float32)
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
self._test_save_load(verts, faces)
|
||||
|
||||
# Empty vertices + empty faces
|
||||
message_regex = "Empty 'verts' and 'faces' arguments provided"
|
||||
verts0 = torch.FloatTensor([])
|
||||
faces0 = torch.LongTensor([])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
self._test_save_load(verts0, faces0)
|
||||
|
||||
faces3 = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
self._test_save_load(verts0, faces3)
|
||||
|
||||
verts3 = torch.zeros(size=(0, 3), dtype=torch.float32)
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
self._test_save_load(verts3, faces0)
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
self._test_save_load(verts3, faces3)
|
||||
|
||||
def test_simple_save(self):
|
||||
verts = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||
@ -378,6 +462,11 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
lines2.insert(5, "property float z")
|
||||
lines2.insert(5, "property float y")
|
||||
lines2[-2] = "0 0 0"
|
||||
lines2[-1] = ""
|
||||
with self.assertRaisesRegex(ValueError, "Not enough data for face."):
|
||||
load_ply(StringIO("\n".join(lines2)))
|
||||
|
||||
lines2[-1] = "2 0 0"
|
||||
with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."):
|
||||
load_ply(StringIO("\n".join(lines2)))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user