diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index fe215ead..8b13b051 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -213,6 +213,17 @@ class _PlyHeader: self.elements.append(_PlyElementType(items[1], count)) +def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor: + """ + Return a 2D tensor with the specified cols and dtype filled with data, + even when data is empty. + """ + if not len(data): + return torch.zeros((0, cols), dtype=dtype) + + return torch.tensor(data, dtype=dtype) + + def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType): """ Given an element which has no lists and one type, read the @@ -227,10 +238,17 @@ def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType): values. There is one column for each property. """ np_type = _PLY_TYPES[definition.properties[0].data_type].np_type - data = np.loadtxt( - f, dtype=np_type, comments=None, ndmin=2, max_rows=definition.count - ) - if data.shape[1] != len(definition.properties): + old_offset = f.tell() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=".* Empty input file.*", category=UserWarning + ) + data = np.loadtxt( + f, dtype=np_type, comments=None, ndmin=2, max_rows=definition.count + ) + if not len(data): # np.loadtxt() seeks even on empty data + f.seek(old_offset) + if definition.count and data.shape[1] != len(definition.properties): raise ValueError("Inconsistent data for %s." % definition.name) if data.shape[0] != definition.count: raise ValueError("Not enough data for %s." % definition.name) @@ -252,7 +270,7 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType): data. The rows are the different values. Otherwise None. """ np_type = _PLY_TYPES[definition.properties[0].data_type].np_type - start_point = f.tell() + old_offset = f.tell() try: with warnings.catch_warnings(): warnings.filterwarnings( @@ -262,8 +280,10 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType): f, dtype=np_type, comments=None, ndmin=2, max_rows=definition.count ) except ValueError: - f.seek(start_point) + f.seek(old_offset) return None + if not len(data): # np.loadtxt() seeks even on empty data + f.seek(old_offset) if (data.shape[1] - 1 != data[:, 0]).any(): msg = "A line of %s data did not have the specified length." raise ValueError(msg % definition.name) @@ -442,7 +462,7 @@ def _try_read_ply_constant_list_binary( [length] = length_struct.unpack(bytes_data) return length - start_point = f.tell() + old_offset = f.tell() length = get_length() np_type = _PLY_TYPES[definition.properties[0].data_type].np_type @@ -459,7 +479,7 @@ def _try_read_ply_constant_list_binary( if i + 1 == definition.count: break if length != get_length(): - f.seek(start_point) + f.seek(old_offset) return None if (sys.byteorder == "big") != big_endian: output = output.byteswap() @@ -644,20 +664,25 @@ def load_ply(f): if face is None: raise ValueError("The ply file has no face element.") - if not isinstance(vertex, np.ndarray) or vertex.ndim != 2 or vertex.shape[1] != 3: + if len(vertex) and ( + not isinstance(vertex, np.ndarray) or vertex.ndim != 2 or vertex.shape[1] != 3 + ): raise ValueError("Invalid vertices in file.") - verts = torch.tensor(vertex, dtype=torch.float32) + verts = _make_tensor(vertex, cols=3, dtype=torch.float32) face_head = next(head for head in header.elements if head.name == "face") if len(face_head.properties) != 1 or face_head.properties[0].list_size_type is None: raise ValueError("Unexpected form of faces data.") # face_head.properties[0].name is usually "vertex_index" or "vertex_indices" # but we don't need to enforce this. - if isinstance(face, np.ndarray) and face.ndim == 2: + + if not len(face): + faces = torch.zeros(size=(0, 3), dtype=torch.int64) + elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements if face.shape[1] < 3: raise ValueError("Faces must have at least 3 vertices.") face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)] - faces = torch.tensor(np.vstack(face_arrays), dtype=torch.int64) + faces = torch.LongTensor(np.vstack(face_arrays)) else: face_list = [] for face_item in face: @@ -667,7 +692,10 @@ def load_ply(f): raise ValueError("Faces must have at least 3 vertices.") for i in range(face_item.shape[0] - 2): face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]]) - faces = torch.tensor(face_list, dtype=torch.int64) + faces = _make_tensor(face_list, cols=3, dtype=torch.int64) + + if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): + warnings.warn("Faces have invalid indices") return verts, faces @@ -682,6 +710,16 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]): faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. """ + if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): + raise ValueError( + "Argument 'verts' should either be empty or of shape (num_verts, 3)." + ) + + if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): + raise ValueError( + "Argument 'faces' should either be empty or of shape (num_faces, 3)." + ) + print("ply\nformat ascii 1.0", file=f) print(f"element vertex {verts.shape[0]}", file=f) print("property float x", file=f) @@ -691,13 +729,25 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]): print("property list uchar int vertex_index", file=f) print("end_header", file=f) + if not (len(verts) or len(faces)): + warnings.warn("Empty 'verts' and 'faces' arguments provided") + return + if decimal_places is None: float_str = "%f" else: float_str = "%" + ".%df" % decimal_places - np.savetxt(f, verts.detach().numpy(), float_str) - np.savetxt(f, faces.detach().numpy(), "3 %d %d %d") + verts_array = verts.detach().numpy() + np.savetxt(f, verts_array, float_str) + + faces_array = faces.detach().numpy() + + if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): + warnings.warn("Faces have invalid indices") + + if len(faces_array): + np.savetxt(f, faces_array, "3 %d %d %d") def save_ply(f, verts, faces, decimal_places: Optional[int] = None): diff --git a/tests/test_ply_io.py b/tests/test_ply_io.py index 34a48bde..9d7e058b 100644 --- a/tests/test_ply_io.py +++ b/tests/test_ply_io.py @@ -139,6 +139,90 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): ] self.assertClose(faces, torch.LongTensor(faces_expected)) + def test_save_ply_invalid_shapes(self): + # Invalid vertices shape + with self.assertRaises(ValueError) as error: + verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) + faces = torch.LongTensor([[0, 1, 2]]) + save_ply(StringIO(), verts, faces) + expected_message = ( + "Argument 'verts' should either be empty or of shape (num_verts, 3)." + ) + self.assertTrue(expected_message, error.exception) + + # Invalid faces shape + with self.assertRaises(ValueError) as error: + verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) + faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) + save_ply(StringIO(), verts, faces) + expected_message = ( + "Argument 'faces' should either be empty or of shape (num_faces, 3)." + ) + self.assertTrue(expected_message, error.exception) + + def test_save_ply_invalid_indices(self): + message_regex = "Faces have invalid indices" + verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) + faces = torch.LongTensor([[0, 1, 2]]) + with self.assertWarnsRegex(UserWarning, message_regex): + save_ply(StringIO(), verts, faces) + + faces = torch.LongTensor([[-1, 0, 1]]) + with self.assertWarnsRegex(UserWarning, message_regex): + save_ply(StringIO(), verts, faces) + + def _test_save_load(self, verts, faces): + f = StringIO() + save_ply(f, verts, faces) + f.seek(0) + # raise Exception(f.getvalue()) + expected_verts, expected_faces = verts, faces + if not len(expected_verts): # Always compare with a (V, 3) tensor + expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32) + if not len(expected_faces): # Always compare with an (F, 3) tensor + expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64) + actual_verts, actual_faces = load_ply(f) + self.assertClose(expected_verts, actual_verts) + self.assertClose(expected_faces, actual_faces) + + def test_empty_save_load(self): + # Vertices + empty faces + verts = torch.tensor([[0.1, 0.2, 0.3]]) + faces = torch.LongTensor([]) + self._test_save_load(verts, faces) + + faces = torch.zeros(size=(0, 3), dtype=torch.int64) + self._test_save_load(verts, faces) + + # Faces + empty vertices + message_regex = "Faces have invalid indices" + verts = torch.FloatTensor([]) + faces = torch.LongTensor([[0, 1, 2]]) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts, faces) + + verts = torch.zeros(size=(0, 3), dtype=torch.float32) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts, faces) + + # Empty vertices + empty faces + message_regex = "Empty 'verts' and 'faces' arguments provided" + verts0 = torch.FloatTensor([]) + faces0 = torch.LongTensor([]) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts0, faces0) + + faces3 = torch.zeros(size=(0, 3), dtype=torch.int64) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts0, faces3) + + verts3 = torch.zeros(size=(0, 3), dtype=torch.float32) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts3, faces0) + + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts3, faces3) + def test_simple_save(self): verts = torch.tensor( [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 @@ -378,6 +462,11 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): lines2.insert(5, "property float z") lines2.insert(5, "property float y") lines2[-2] = "0 0 0" + lines2[-1] = "" + with self.assertRaisesRegex(ValueError, "Not enough data for face."): + load_ply(StringIO("\n".join(lines2))) + + lines2[-1] = "2 0 0" with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."): load_ply(StringIO("\n".join(lines2)))