From 95707fba1c8eff99757691468f440f9ce63c8027 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 7 Jan 2021 15:38:49 -0800 Subject: [PATCH] PLY pointcloud loading Summary: Allow PLY files to not contain faces. Allow loading pointclouds with color, at least encoded according to the way of some cloudcompare examples. TODO: Allow vertex normals to be read, and allow vertex colors to be written. Make the return type of load_ply something more user friendly, like a dict. Noticed in https://github.com/facebookresearch/pytorch3d/issues/209 Reviewed By: nikhilaravi Differential Revision: D22573314 fbshipit-source-id: 72ba1f7c6417f5dfc83f2ebf359eff017057635c --- pytorch3d/io/pluggable.py | 3 +- pytorch3d/io/ply_io.py | 389 +++++++++++++++++++++++++++++--------- tests/test_ply_io.py | 124 +++++++++++- 3 files changed, 423 insertions(+), 93 deletions(-) diff --git a/pytorch3d/io/pluggable.py b/pytorch3d/io/pluggable.py index b9f0e035..910a1a83 100644 --- a/pytorch3d/io/pluggable.py +++ b/pytorch3d/io/pluggable.py @@ -12,7 +12,7 @@ from pytorch3d.structures import Meshes, Pointclouds from .obj_io import MeshObjFormat from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter -from .ply_io import MeshPlyFormat +from .ply_io import MeshPlyFormat, PointcloudPlyFormat """ @@ -74,6 +74,7 @@ class IO: def register_default_formats(self) -> None: self.register_meshes_format(MeshObjFormat()) self.register_meshes_format(MeshPlyFormat()) + self.register_pointcloud_format(PointcloudPlyFormat()) def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None: """ diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index c2eded04..019041c5 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -3,23 +3,30 @@ # LICENSE file in the root directory of this source tree. -"""This module implements utility functions for loading and saving meshes.""" +""" +This module implements utility functions for loading and saving +meshes and point clouds from PLY files. +""" import itertools import struct import sys import warnings from collections import namedtuple -from io import BytesIO +from io import BytesIO, TextIOBase from pathlib import Path -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch from iopath.common.file_io import PathManager from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file -from pytorch3d.structures import Meshes +from pytorch3d.structures import Meshes, Pointclouds -from .pluggable_formats import MeshFormatInterpreter, endswith +from .pluggable_formats import ( + MeshFormatInterpreter, + PointcloudFormatInterpreter, + endswith, +) _PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type") @@ -127,7 +134,7 @@ class _PlyHeader: self.elements: (List[_PlyElementType]) element description self.ascii: (bool) Whether in ascii format self.big_endian: (bool) (if not ascii) whether big endian - self.obj_info: (dict) arbitrary extra data + self.obj_info: (List[str]) arbitrary extra data Args: f: file-like object. @@ -136,7 +143,7 @@ class _PlyHeader: raise ValueError("Invalid file header.") seen_format = False self.elements = [] - self.obj_info = {} + self.obj_info = [] while True: line = f.readline() if isinstance(line, bytes): @@ -172,11 +179,8 @@ class _PlyHeader: if line.startswith("element"): self._parse_element(line) continue - if line.startswith("obj_info"): - items = line.split(" ") - if len(items) != 3: - raise ValueError("Invalid line: %s" % line) - self.obj_info[items[1]] = items[2] + if line.startswith("obj_info "): + self.obj_info.append(line[9:]) continue if line.startswith("property"): self._parse_property(line) @@ -736,6 +740,10 @@ def _load_ply_raw_stream(f) -> Tuple[_PlyHeader, dict]: for element in header.elements: elements[element.name] = _read_ply_element_ascii(f, element) else: + if isinstance(f, TextIOBase): + raise ValueError( + "Cannot safely read a binary ply file using a Text stream." + ) big = header.big_endian for element in header.elements: elements[element.name] = _read_ply_element_binary(f, element, big) @@ -769,7 +777,187 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]: return header, elements -def load_ply(f, path_manager: Optional[PathManager] = None): +def _get_verts_column_indices( + vertex_head: _PlyElementType, +) -> Tuple[List[int], Optional[List[int]]]: + """ + Get the columns of verts and verts_colors in the vertex + element of a parsed ply file. + + Args: + vertex_head: as returned from load_ply_raw. + + Returns: + point_idxs: List[int] of 3 point columns. + color_idxs: List[int] of 3 color columns if they are present, + otherwise None. + """ + point_idxs: List[Optional[int]] = [None, None, None] + color_idxs: List[Optional[int]] = [None, None, None] + for i, prop in enumerate(vertex_head.properties): + if prop.list_size_type is not None: + raise ValueError("Invalid vertices in file: did not expect list.") + for j, letter in enumerate(["x", "y", "z"]): + if prop.name == letter: + point_idxs[j] = i + for j, name in enumerate(["red", "green", "blue"]): + if prop.name == name: + color_idxs[j] = i + if None in point_idxs: + raise ValueError("Invalid vertices in file.") + if None in color_idxs: + return point_idxs, None + return point_idxs, color_idxs + + +def _get_verts( + header: _PlyHeader, elements: dict +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Get the vertex locations and colors from a parsed ply file. + + Args: + header, elements: as returned from load_ply_raw. + + Returns: + verts: FloatTensor of shape (V, 3). + vertex_colors: None or FloatTensor of shape (V, 3). + """ + + vertex = elements.get("vertex", None) + if vertex is None: + raise ValueError("The ply file has no vertex element.") + if not isinstance(vertex, list): + raise ValueError("Invalid vertices in file.") + vertex_head = next(head for head in header.elements if head.name == "vertex") + point_idxs, color_idxs = _get_verts_column_indices(vertex_head) + + # Case of no vertices + if vertex_head.count == 0: + verts = torch.zeros((0, 3), dtype=torch.float32) + if color_idxs is None: + return verts, None + return verts, torch.zeros((0, 3), dtype=torch.float32) + + # Simple case where the only data is the vertices themselves + if ( + len(vertex) == 1 + and isinstance(vertex[0], np.ndarray) + and vertex[0].ndim == 2 + and vertex[0].shape[1] == 3 + ): + return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None + + vertex_colors = None + + if len(vertex) == 1: + # This is the case where the whole vertex element has one type, + # so it was read as a single array and we can index straight into it. + verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32) + if color_idxs is not None: + vertex_colors = torch.tensor(vertex[0][:, color_idxs], dtype=torch.float32) + else: + # The vertex element is heterogeneous. It was read as several arrays, + # part by part, where a part is a set of properties with the same type. + # For each property (=column in the file), we store in + # prop_to_partnum_col its partnum (i.e. the index of what part it is + # in) and its column number (its index within its part). + prop_to_partnum_col = [ + (partnum, col) + for partnum, array in enumerate(vertex) + for col in range(array.shape[1]) + ] + verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32) + for axis in range(3): + partnum, col = prop_to_partnum_col[point_idxs[axis]] + verts.numpy()[:, axis] = vertex[partnum][:, col] + # Note that in the previous line, we made the assignment + # as numpy arrays by casting verts. If we took the (more + # obvious) method of converting the right hand side to + # torch, then we might have an extra data copy because + # torch wants contiguity. The code would be like: + # if not vertex[partnum].flags["C_CONTIGUOUS"]: + # vertex[partnum] = np.ascontiguousarray(vertex[partnum]) + # verts[:, axis] = torch.tensor((vertex[partnum][:, col])) + if color_idxs is not None: + vertex_colors = torch.empty( + size=(vertex_head.count, 3), dtype=torch.float32 + ) + for color in range(3): + partnum, col = prop_to_partnum_col[color_idxs[color]] + vertex_colors.numpy()[:, color] = vertex[partnum][:, col] + + return verts, vertex_colors + + +def _load_ply( + f, *, path_manager: PathManager, return_vertex_colors: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Load the data from a .ply file. + + Args: + f: A binary or text file-like object (with methods read, readline, + tell and seek), a pathlib path or a string containing a file name. + If the ply file is in the binary ply format rather than the text + ply format, then a text stream is not supported. + It is easiest to use a binary stream in all cases. + path_manager: PathManager for loading if f is a str. + return_vertex_colors: whether to return vertex colors. + + Returns: + verts: FloatTensor of shape (V, 3). + faces: None or LongTensor of vertex indices, shape (F, 3). + vertex_colors: None or FloatTensor of shape (V, 3), only if requested + """ + header, elements = _load_ply_raw(f, path_manager=path_manager) + + verts, vertex_colors = _get_verts(header, elements) + + face = elements.get("face", None) + if face is not None: + face_head = next(head for head in header.elements if head.name == "face") + if ( + len(face_head.properties) != 1 + or face_head.properties[0].list_size_type is None + ): + raise ValueError("Unexpected form of faces data.") + # face_head.properties[0].name is usually "vertex_index" or "vertex_indices" + # but we don't need to enforce this. + + if face is None: + faces = None + elif not len(face): + # pyre is happier when this condition is not joined to the + # previous one with `or`. + faces = None + elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements + if face.shape[1] < 3: + raise ValueError("Faces must have at least 3 vertices.") + face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)] + faces = torch.LongTensor(np.vstack(face_arrays)) + else: + face_list = [] + for face_item in face: + if face_item.ndim != 1: + raise ValueError("Bad face data.") + if face_item.shape[0] < 3: + raise ValueError("Faces must have at least 3 vertices.") + for i in range(face_item.shape[0] - 2): + face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]]) + faces = torch.tensor(face_list, dtype=torch.int64) + + if faces is not None: + _check_faces_indices(faces, max_index=verts.shape[0]) + + if return_vertex_colors: + return verts, faces, vertex_colors + return verts, faces, None + + +def load_ply( + f, *, path_manager: Optional[PathManager] = None +) -> Tuple[torch.Tensor, torch.Tensor]: """ Load the data from a .ply file. @@ -809,72 +997,27 @@ def load_ply(f, path_manager: Optional[PathManager] = None): It is easiest to use a binary stream in all cases. path_manager: PathManager for loading if f is a str. - Returns: verts: FloatTensor of shape (V, 3). faces: LongTensor of vertex indices, shape (F, 3). """ + if path_manager is None: path_manager = PathManager() - header, elements = _load_ply_raw(f, path_manager=path_manager) + verts, faces, _ = _load_ply(f, path_manager=path_manager) + if faces is None: + faces = torch.zeros(0, 3, dtype=torch.int64) - vertex = elements.get("vertex", None) - if vertex is None: - raise ValueError("The ply file has no vertex element.") - - face = elements.get("face", None) - if face is None: - raise ValueError("The ply file has no face element.") - - if not isinstance(vertex, list) or len(vertex) > 1: - raise ValueError("Invalid vertices in file.") - - if len(vertex): - vertex0 = vertex[0] - if len(vertex0) and ( - not isinstance(vertex0, np.ndarray) - or vertex0.ndim != 2 - or vertex0.shape[1] != 3 - ): - raise ValueError("Invalid vertices in file.") - else: - vertex0 = [] - verts = _make_tensor(vertex0, cols=3, dtype=torch.float32) - - face_head = next(head for head in header.elements if head.name == "face") - if len(face_head.properties) != 1 or face_head.properties[0].list_size_type is None: - raise ValueError("Unexpected form of faces data.") - # face_head.properties[0].name is usually "vertex_index" or "vertex_indices" - # but we don't need to enforce this. - - if not len(face): - faces = torch.zeros((0, 3), dtype=torch.int64) - elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements - if face.shape[1] < 3: - raise ValueError("Faces must have at least 3 vertices.") - face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)] - faces = torch.LongTensor(np.vstack(face_arrays)) - else: - face_list = [] - for face_item in face: - if face_item.ndim != 1: - raise ValueError("Bad face data.") - if face_item.shape[0] < 3: - raise ValueError("Faces must have at least 3 vertices.") - for i in range(face_item.shape[0] - 2): - face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]]) - # pyre-fixme[6]: Expected `dtype` for 3rd param but got `Type[torch.int64]`. - faces = _make_tensor(face_list, cols=3, dtype=torch.int64) - - _check_faces_indices(faces, max_index=verts.shape[0]) return verts, faces def _save_ply( f, + *, verts: torch.Tensor, - faces: torch.LongTensor, + faces: Optional[torch.LongTensor], verts_normals: torch.Tensor, + verts_colors: torch.Tensor, ascii: bool, decimal_places: Optional[int] = None, ) -> None: @@ -890,10 +1033,14 @@ def _save_ply( decimal_places: Number of decimal places for saving if ascii=True. """ assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) - assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) + if faces is not None: + assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) assert not len(verts_normals) or ( verts_normals.dim() == 2 and verts_normals.size(1) == 3 ) + assert not len(verts_colors) or ( + verts_colors.dim() == 2 and verts_colors.size(1) == 3 + ) if ascii: f.write(b"ply\nformat ascii 1.0\n") @@ -909,15 +1056,20 @@ def _save_ply( f.write(b"property float nx\n") f.write(b"property float ny\n") f.write(b"property float nz\n") - f.write(f"element face {faces.shape[0]}\n".encode("ascii")) - f.write(b"property list uchar int vertex_index\n") + if verts_colors.numel() > 0: + f.write(b"property float red\n") + f.write(b"property float green\n") + f.write(b"property float blue\n") + if len(verts) and faces is not None: + f.write(f"element face {faces.shape[0]}\n".encode("ascii")) + f.write(b"property list uchar int vertex_index\n") f.write(b"end_header\n") - if not (len(verts) or len(faces)): - warnings.warn("Empty 'verts' and 'faces' arguments provided") + if not (len(verts)): + warnings.warn("Empty 'verts' provided") return - vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy() + vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy() if ascii: if decimal_places is None: float_str = "%f" @@ -932,21 +1084,22 @@ def _save_ply( else: vert_data.tofile(f) - faces_array = faces.detach().numpy() + if faces is not None: + faces_array = faces.detach().numpy() - _check_faces_indices(faces, max_index=verts.shape[0]) + _check_faces_indices(faces, max_index=verts.shape[0]) - if len(faces_array): - if ascii: - np.savetxt(f, faces_array, "3 %d %d %d") - else: - # rows are 13 bytes: a one-byte 3 followed by three four-byte face indices. - faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8) - faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8) - if isinstance(f, BytesIO): - f.write(faces_uints.tobytes()) + if len(faces_array): + if ascii: + np.savetxt(f, faces_array, "3 %d %d %d") else: - faces_uints.tofile(f) + # rows are 13 bytes: a one-byte 3 followed by three four-byte face indices. + faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8) + faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8) + if isinstance(f, BytesIO): + f.write(faces_uints.tobytes()) + else: + faces_uints.tofile(f) def save_ply( @@ -977,13 +1130,16 @@ def save_ply( if verts_normals is None else verts_normals ) - faces = torch.LongTensor([]) if faces is None else faces if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." raise ValueError(message) - if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): + if ( + faces is not None + and len(faces) + and not (faces.dim() == 2 and faces.size(1) == 3) + ): message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." raise ValueError(message) @@ -995,10 +1151,20 @@ def save_ply( message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." raise ValueError(message) + verts_colors = torch.FloatTensor([]) + if path_manager is None: path_manager = PathManager() with _open_file(f, path_manager, "wb") as f: - _save_ply(f, verts, faces, verts_normals, ascii, decimal_places) + _save_ply( + f, + verts=verts, + faces=faces, + verts_normals=verts_normals, + verts_colors=verts_colors, + ascii=ascii, + decimal_places=decimal_places, + ) class MeshPlyFormat(MeshFormatInterpreter): @@ -1044,3 +1210,54 @@ class MeshPlyFormat(MeshFormatInterpreter): path_manager=path_manager, ) return True + + +class PointcloudPlyFormat(PointcloudFormatInterpreter): + def __init__(self): + self.known_suffixes = (".ply",) + + def read( + self, + path: Union[str, Path], + device, + path_manager: PathManager, + **kwargs, + ) -> Optional[Pointclouds]: + if not endswith(path, self.known_suffixes): + return None + + verts, faces, features = _load_ply( + f=path, path_manager=path_manager, return_vertex_colors=True + ) + verts = verts.to(device) + if features is None: + pointcloud = Pointclouds(points=[verts]) + else: + pointcloud = Pointclouds(points=[verts], features=[features.to(device)]) + return pointcloud + + def save( + self, + data: Pointclouds, + path: Union[str, Path], + path_manager: PathManager, + binary: Optional[bool], + decimal_places: Optional[int] = None, + **kwargs, + ) -> bool: + if not endswith(path, self.known_suffixes): + return False + + points = data.points_list()[0] + features = data.features_list()[0] + with _open_file(path, path_manager, "wb") as f: + _save_ply( + f=f, + verts=points, + verts_colors=features, + verts_normals=torch.FloatTensor([]), + faces=None, + ascii=binary is False, + decimal_places=decimal_places, + ) + return True diff --git a/tests/test_ply_io.py b/tests/test_ply_io.py index 2c24f6ee..abc3f39d 100644 --- a/tests/test_ply_io.py +++ b/tests/test_ply_io.py @@ -12,6 +12,7 @@ from common_testing import TestCaseMixin from iopath.common.file_io import PathManager from pytorch3d.io import IO from pytorch3d.io.ply_io import load_ply, save_ply +from pytorch3d.structures import Pointclouds from pytorch3d.utils import torus @@ -229,9 +230,13 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32) if not len(expected_faces): # Always compare with an (F, 3) tensor expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64) + actual_verts, actual_faces = load_ply(f) self.assertClose(expected_verts, actual_verts) - self.assertClose(expected_faces, actual_faces) + if len(actual_verts): + self.assertClose(expected_faces, actual_faces) + else: + self.assertEqual(actual_faces.numel(), 0) def test_normals_save(self): verts = torch.tensor( @@ -255,9 +260,10 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): self._test_save_load(verts, faces) # Faces + empty vertices - message_regex = "Faces have invalid indices" + # => We don't save the faces verts = torch.FloatTensor([]) faces = torch.LongTensor([[0, 1, 2]]) + message_regex = "Empty 'verts' provided" with self.assertWarnsRegex(UserWarning, message_regex): self._test_save_load(verts, faces) @@ -266,7 +272,6 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): self._test_save_load(verts, faces) # Empty vertices + empty faces - message_regex = "Empty 'verts' and 'faces' arguments provided" verts0 = torch.FloatTensor([]) faces0 = torch.LongTensor([]) with self.assertWarnsRegex(UserWarning, message_regex): @@ -354,6 +359,115 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): self.assertClose(x, X) self.assertClose(yz, YZ.reshape(8, 2)) + def test_load_cloudcompare_pointcloud(self): + """ + Test loading a pointcloud styled like some cloudcompare output. + cloudcompare is an open source 3D point cloud processing software. + """ + header = "\n".join( + [ + "ply", + "format binary_little_endian 1.0", + "obj_info Not a key-value pair!", + "element vertex 8", + "property double x", + "property double y", + "property double z", + "property uchar red", + "property uchar green", + "property uchar blue", + "property float my_Favorite", + "end_header", + "", + ] + ).encode("ascii") + data = struct.pack("<" + "dddBBBf" * 8, *range(56)) + io = IO() + with NamedTemporaryFile(mode="wb", suffix=".ply") as f: + f.write(header) + f.write(data) + f.flush() + pointcloud = io.load_pointcloud(f.name) + + self.assertClose( + pointcloud.points_padded()[0], + torch.FloatTensor([0, 1, 2]) + 7 * torch.arange(8)[:, None], + ) + self.assertClose( + pointcloud.features_padded()[0], + torch.FloatTensor([3, 4, 5]) + 7 * torch.arange(8)[:, None], + ) + + def test_save_pointcloud(self): + header = "\n".join( + [ + "ply", + "format binary_little_endian 1.0", + "element vertex 8", + "property float x", + "property float y", + "property float z", + "property float red", + "property float green", + "property float blue", + "end_header", + "", + ] + ).encode("ascii") + data = struct.pack("<" + "f" * 48, *range(48)) + points = torch.FloatTensor([0, 1, 2]) + 6 * torch.arange(8)[:, None] + features = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None] + pointcloud = Pointclouds(points=[points], features=[features]) + + io = IO() + with NamedTemporaryFile(mode="rb", suffix=".ply") as f: + io.save_pointcloud(data=pointcloud, path=f.name) + f.flush() + f.seek(0) + actual_data = f.read() + reloaded_pointcloud = io.load_pointcloud(f.name) + + self.assertEqual(header + data, actual_data) + self.assertClose(reloaded_pointcloud.points_list()[0], points) + self.assertClose(reloaded_pointcloud.features_list()[0], features) + + with NamedTemporaryFile(mode="r", suffix=".ply") as f: + io.save_pointcloud(data=pointcloud, path=f.name, binary=False) + reloaded_pointcloud2 = io.load_pointcloud(f.name) + self.assertEqual(f.readline(), "ply\n") + self.assertEqual(f.readline(), "format ascii 1.0\n") + self.assertClose(reloaded_pointcloud2.points_list()[0], points) + self.assertClose(reloaded_pointcloud2.features_list()[0], features) + + def test_load_pointcloud_bad_order(self): + """ + Ply file with a strange property order + """ + file = "\n".join( + [ + "ply", + "format ascii 1.0", + "element vertex 1", + "property uchar green", + "property float x", + "property float z", + "property uchar red", + "property float y", + "property uchar blue", + "end_header", + "1 2 3 4 5 6", + ] + ) + + io = IO() + pointcloud_gpu = io.load_pointcloud(StringIO(file), device="cuda:0") + self.assertEqual(pointcloud_gpu.device, torch.device("cuda:0")) + pointcloud = pointcloud_gpu.to(torch.device("cpu")) + expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32) + expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32) + self.assertClose(pointcloud.points_padded(), expected_points) + self.assertClose(pointcloud.features_padded(), expected_features) + def test_load_simple_binary(self): for big_endian in [True, False]: verts = ( @@ -569,9 +683,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): with self.assertRaisesRegex(ValueError, "Inconsistent data for vertex."): _load_ply_raw(StringIO("\n".join(lines2))) - # Now make the ply file actually be readable as a Mesh - - with self.assertRaisesRegex(ValueError, "The ply file has no face element."): + with self.assertRaisesRegex(ValueError, "Invalid vertices in file."): load_ply(StringIO("\n".join(lines))) lines2 = lines.copy()