diff --git a/pytorch3d/io/off_io.py b/pytorch3d/io/off_io.py new file mode 100644 index 00000000..30244370 --- /dev/null +++ b/pytorch3d/io/off_io.py @@ -0,0 +1,488 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" +This module implements utility functions for loading and saving +meshes as .off files. +""" +import warnings +from pathlib import Path +from typing import Optional, Tuple, Union, cast + +import numpy as np +import torch +from iopath.common.file_io import PathManager +from pytorch3d.io.utils import _check_faces_indices, _open_file +from pytorch3d.renderer import TexturesAtlas, TexturesVertex +from pytorch3d.structures import Meshes + +from .pluggable_formats import MeshFormatInterpreter, endswith + + +def _is_line_empty(line: Union[str, bytes]) -> bool: + """ + Returns whether line is not relevant in an OFF file. + """ + line = line.strip() + return len(line) == 0 or line[:1] == b"#" + + +def _count_next_line_periods(file) -> int: + """ + Returns the number of . characters before any # on the next + meaningful line. + """ + old_offset = file.tell() + line = file.readline() + while _is_line_empty(line): + line = file.readline() + if len(line) == 0: + raise ValueError("Premature end of file") + + contents = line.split(b"#")[0] + count = contents.count(b".") + file.seek(old_offset) + return count + + +def _read_faces_lump( + file, n_faces: int, n_colors: Optional[int] +) -> Optional[Tuple[np.ndarray, int, Optional[np.ndarray]]]: + """ + Parse n_faces faces and faces_colors from the file, + if they all have the same number of vertices. + This is used in two ways. + 1) To try to read all faces. + 2) To read faces one-by-one if that failed. + + Args: + file: file-like object being read. + n_faces: The known number of faces yet to read. + n_colors: The number of colors if known already. + + Returns: + - 2D numpy array of faces + - number of colors found + - 2D numpy array of face colors if found. + of None if there are faces with different numbers of vertices. + """ + if n_faces == 0: + return np.array([[]]), 0, None + old_offset = file.tell() + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=".* Empty input file.*", category=UserWarning + ) + data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_faces) + except ValueError as e: + if n_faces > 1 and "Wrong number of columns" in e.args[0]: + file.seek(old_offset) + return None + raise ValueError("Not enough face data.") + + if len(data) != n_faces: + raise ValueError("Not enough face data.") + face_size = int(data[0, 0]) + if (data[:, 0] != face_size).any(): + msg = "A line of face data did not have the specified length." + raise ValueError(msg) + if face_size < 3: + raise ValueError("Faces must have at least 3 vertices.") + + n_colors_found = data.shape[1] - 1 - face_size + if n_colors is not None and n_colors_found != n_colors: + raise ValueError("Number of colors differs between faces.") + n_colors = n_colors_found + if n_colors not in [0, 3, 4]: + raise ValueError("Unexpected number of colors.") + + face_raw_data = data[:, 1 : 1 + face_size].astype("int64") + if face_size == 3: + face_data = face_raw_data + else: + face_arrays = [ + face_raw_data[:, [0, i + 1, i + 2]] for i in range(face_size - 2) + ] + face_data = np.vstack(face_arrays) + + if n_colors == 0: + return face_data, 0, None + colors = data[:, 1 + face_size :] + if face_size == 3: + return face_data, n_colors, colors + return face_data, n_colors, np.tile(colors, (face_size - 2, 1)) + + +def _read_faces( + file, n_faces: int +) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + """ + Returns faces and face colors from the file. + + Args: + file: file-like object being read. + n_faces: The known number of faces. + + Returns: + 2D numpy arrays of faces and face colors, or None for each if + they are not present. + """ + if n_faces == 0: + return None, None + + color_is_int = 0 == _count_next_line_periods(file) + color_scale = 1 / 255.0 if color_is_int else 1 + + faces_ncolors_colors = _read_faces_lump(file, n_faces=n_faces, n_colors=None) + if faces_ncolors_colors is not None: + faces, _, colors = faces_ncolors_colors + if colors is None: + return faces, None + return faces, colors * color_scale + + faces_list, colors_list = [], [] + n_colors = None + for _ in range(n_faces): + faces_ncolors_colors = _read_faces_lump(file, n_faces=1, n_colors=n_colors) + faces_found, n_colors, colors_found = cast( + Tuple[np.ndarray, int, Optional[np.ndarray]], faces_ncolors_colors + ) + faces_list.append(faces_found) + colors_list.append(colors_found) + faces = np.vstack(faces_list) + if n_colors == 0: + colors = None + else: + colors = np.vstack(colors_list) * color_scale + return faces, colors + + +def _read_verts(file, n_verts: int) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Returns verts and vertex colors from the file. + + Args: + file: file-like object being read. + n_verts: The known number of faces. + + Returns: + 2D numpy arrays of verts and (if present) + vertex colors. + """ + + color_is_int = 3 == _count_next_line_periods(file) + color_scale = 1 / 255.0 if color_is_int else 1 + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=".* Empty input file.*", category=UserWarning + ) + data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_verts) + if data.shape[0] != n_verts: + raise ValueError("Not enough vertex data.") + if data.shape[1] not in [3, 6, 7]: + raise ValueError("Bad vertex data.") + + if data.shape[1] == 3: + return data, None + return data[:, :3], data[:, 3:] * color_scale # [] + + +def _load_off_stream(file) -> dict: + """ + Load the data from a stream of an .off file. + + Example .off file format: + + off + 8 6 1927 { number of vertices, faces, and (not used) edges } + # comment { comments with # sign } + 0 0 0 { start of vertex list } + 0 0 1 + 0 1 1 + 0 1 0 + 1 0 0 + 1 0 1 + 1 1 1 + 1 1 0 + 4 0 1 2 3 { start of face list } + 4 7 6 5 4 + 4 0 4 5 1 + 4 1 5 6 2 + 4 2 6 7 3 + 4 3 7 4 0 + + Args: + file: A binary file-like object (with methods read, readline, + tell and seek). + + Returns dictionary possibly containing: + verts: (always present) FloatTensor of shape (V, 3). + verts_colors: FloatTensor of shape (V, C) where C is 3 or 4. + faces: LongTensor of vertex indices, split into triangles, shape (F, 3). + faces_colors: FloatTensor of shape (F, C), where C is 3 or 4. + """ + header = file.readline() + if header.lower() in (b"off\n", b"off\r\n", "off\n"): + header = file.readline() + + while _is_line_empty(header): + header = file.readline() + + items = header.split(b" ") + if len(items) and items[0].lower() in ("off", b"off"): + items = items[1:] + if len(items) < 3: + raise ValueError("Invalid counts line: %s" % header) + + try: + n_verts = int(items[0]) + except ValueError: + raise ValueError("Invalid counts line: %s" % header) + try: + n_faces = int(items[1]) + except ValueError: + raise ValueError("Invalid counts line: %s" % header) + + if (len(items) > 3 and not items[3].startswith("#")) or n_verts < 0 or n_faces < 0: + raise ValueError("Invalid counts line: %s" % header) + + verts, verts_colors = _read_verts(file, n_verts) + faces, faces_colors = _read_faces(file, n_faces) + + end = file.read().strip() + if len(end) != 0: + raise ValueError("Extra data at end of file: " + str(end[:20])) + + out = {"verts": verts} + if verts_colors is not None: + out["verts_colors"] = verts_colors + if faces is not None: + out["faces"] = faces + if faces_colors is not None: + out["faces_colors"] = faces_colors + return out + + +def _write_off_data( + file, + verts: torch.Tensor, + verts_colors: Optional[torch.Tensor] = None, + faces: Optional[torch.LongTensor] = None, + faces_colors: Optional[torch.Tensor] = None, + decimal_places: Optional[int] = None, +) -> None: + """ + Internal implementation for saving 3D data to a .off file. + + Args: + file: Binary file object to which the 3D data should be written. + verts: FloatTensor of shape (V, 3) giving vertex coordinates. + verts_colors: FloatTensor of shape (V, C) giving vertex colors where C is 3 or 4. + faces: LongTensor of shape (F, 3) giving faces. + faces_colors: FloatTensor of shape (V, C) giving face colors where C is 3 or 4. + decimal_places: Number of decimal places for saving. + """ + nfaces = 0 if faces is None else faces.shape[0] + file.write(f"off\n{verts.shape[0]} {nfaces} 0\n".encode("ascii")) + + if verts_colors is not None: + verts = torch.cat((verts, verts_colors), dim=1) + if decimal_places is None: + float_str = "%f" + else: + float_str = "%" + ".%df" % decimal_places + np.savetxt(file, verts.cpu().detach().numpy(), float_str) + + if faces is not None: + _check_faces_indices(faces, max_index=verts.shape[0]) + + if faces_colors is not None: + face_data = torch.cat( + [ + cast(torch.Tensor, faces).cpu().to(torch.float64), + faces_colors.detach().cpu().to(torch.float64), + ], + dim=1, + ) + format = "3 %d %d %d" + " %f" * faces_colors.shape[1] + np.savetxt(file, face_data.numpy(), format) + elif faces is not None: + np.savetxt(file, faces.cpu().detach().numpy(), "3 %d %d %d") + + +def _save_off( + file, + *, + verts: torch.Tensor, + verts_colors: Optional[torch.Tensor] = None, + faces: Optional[torch.LongTensor] = None, + faces_colors: Optional[torch.Tensor] = None, + decimal_places: Optional[int] = None, + path_manager: PathManager, +) -> None: + """ + Save a mesh to an ascii .off file. + + Args: + file: File (or path) to which the mesh should be written. + verts: FloatTensor of shape (V, 3) giving vertex coordinates. + verts_colors: FloatTensor of shape (V, C) giving vertex colors where C is 3 or 4. + faces: LongTensor of shape (F, 3) giving faces. + faces_colors: FloatTensor of shape (V, C) giving face colors where C is 3 or 4. + decimal_places: Number of decimal places for saving. + """ + if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): + message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + + if verts_colors is not None and 0 == len(verts_colors): + verts_colors = None + if faces_colors is not None and 0 == len(faces_colors): + faces_colors = None + if faces is not None and 0 == len(faces): + faces = None + + if verts_colors is not None: + if not (verts_colors.dim() == 2 and verts_colors.size(1) in [3, 4]): + message = "verts_colors should have shape (num_faces, C)." + raise ValueError(message) + if verts_colors.shape[0] != verts.shape[0]: + message = "verts_colors should have the same length as verts." + raise ValueError(message) + + if faces is not None and not (faces.dim() == 2 and faces.size(1) == 3): + message = "Argument 'faces' if present should have shape (num_faces, 3)." + raise ValueError(message) + if faces_colors is not None and faces is None: + message = "Cannot have face colors without faces" + raise ValueError(message) + + if faces_colors is not None: + if not (faces_colors.dim() == 2 and faces_colors.size(1) in [3, 4]): + message = "faces_colors should have shape (num_faces, C)." + raise ValueError(message) + if faces_colors.shape[0] != cast(torch.LongTensor, faces).shape[0]: + message = "faces_colors should have the same length as faces." + raise ValueError(message) + + with _open_file(file, path_manager, "wb") as f: + _write_off_data(f, verts, verts_colors, faces, faces_colors, decimal_places) + + +class MeshOffFormat(MeshFormatInterpreter): + """ + Loads and saves meshes in the ascii OFF format. This is a simple + format which can only deal with the following texture types: + + - TexturesVertex, i.e. one color for each vertex + - TexturesAtlas with R=1, i.e. one color for each face. + + There are some possible features of OFF files which we do not support + and which appear to be rare: + + - Four dimensional data. + - Binary data. + - Vertex Normals. + - Texture coordinates. + - "COFF" header. + + Example .off file format: + + off + 8 6 1927 { number of vertices, faces, and (not used) edges } + # comment { comments with # sign } + 0 0 0 { start of vertex list } + 0 0 1 + 0 1 1 + 0 1 0 + 1 0 0 + 1 0 1 + 1 1 1 + 1 1 0 + 4 0 1 2 3 { start of face list } + 4 7 6 5 4 + 4 0 4 5 1 + 4 1 5 6 2 + 4 2 6 7 3 + 4 3 7 4 0 + + """ + + def __init__(self): + self.known_suffixes = (".off",) + + def read( + self, + path: Union[str, Path], + include_textures: bool, + device, + path_manager: PathManager, + **kwargs, + ) -> Optional[Meshes]: + if not endswith(path, self.known_suffixes): + return None + + with _open_file(path, path_manager, "rb") as f: + data = _load_off_stream(f) + verts = torch.from_numpy(data["verts"]).to(device) + if "faces" in data: + faces = torch.from_numpy(data["faces"]).to(dtype=torch.int64, device=device) + else: + faces = torch.zeros((0, 3), dtype=torch.int64, device=device) + + textures = None + if "verts_colors" in data: + if "faces_colors" in data: + msg = "Faces colors ignored because vertex colors provided too." + warnings.warn(msg) + verts_colors = torch.from_numpy(data["verts_colors"]).to(device) + textures = TexturesVertex([verts_colors]) + elif "faces_colors" in data: + faces_colors = torch.from_numpy(data["faces_colors"]).to(device) + textures = TexturesAtlas([faces_colors[:, None, None, :]]) + + mesh = Meshes( + verts=[verts.to(device)], faces=[faces.to(device)], textures=textures + ) + return mesh + + def save( + self, + data: Meshes, + path: Union[str, Path], + path_manager: PathManager, + binary: Optional[bool], + decimal_places: Optional[int] = None, + **kwargs, + ) -> bool: + if not endswith(path, self.known_suffixes): + return False + + verts = data.verts_list()[0] + faces = data.faces_list()[0] + if isinstance(data.textures, TexturesVertex): + [verts_colors] = data.textures.verts_features_list() + else: + verts_colors = None + + faces_colors = None + if isinstance(data.textures, TexturesAtlas): + [atlas] = data.textures.atlas_list() + F, R, _, D = atlas.shape + if R == 1: + faces_colors = atlas[:, 0, 0, :] + + _save_off( + file=path, + verts=verts, + faces=faces, + verts_colors=verts_colors, + faces_colors=faces_colors, + decimal_places=decimal_places, + path_manager=path_manager, + ) + return True diff --git a/pytorch3d/io/pluggable.py b/pytorch3d/io/pluggable.py index 910a1a83..6d1b4691 100644 --- a/pytorch3d/io/pluggable.py +++ b/pytorch3d/io/pluggable.py @@ -11,6 +11,7 @@ from iopath.common.file_io import PathManager from pytorch3d.structures import Meshes, Pointclouds from .obj_io import MeshObjFormat +from .off_io import MeshOffFormat from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter from .ply_io import MeshPlyFormat, PointcloudPlyFormat @@ -73,6 +74,7 @@ class IO: def register_default_formats(self) -> None: self.register_meshes_format(MeshObjFormat()) + self.register_meshes_format(MeshOffFormat()) self.register_meshes_format(MeshPlyFormat()) self.register_pointcloud_format(PointcloudPlyFormat()) diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 019041c5..2df8461f 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -5,7 +5,7 @@ """ This module implements utility functions for loading and saving -meshes and point clouds from PLY files. +meshes and point clouds as PLY files. """ import itertools import struct diff --git a/tests/bm_mesh_io.py b/tests/bm_mesh_io.py index a8f43be2..6fc914bb 100644 --- a/tests/bm_mesh_io.py +++ b/tests/bm_mesh_io.py @@ -3,8 +3,8 @@ from itertools import product from fvcore.common.benchmark import benchmark -from test_obj_io import TestMeshObjIO -from test_ply_io import TestMeshPlyIO +from test_io_obj import TestMeshObjIO +from test_io_ply import TestMeshPlyIO def bm_save_load() -> None: diff --git a/tests/test_obj_io.py b/tests/test_io_obj.py similarity index 100% rename from tests/test_obj_io.py rename to tests/test_io_obj.py diff --git a/tests/test_io_off.py b/tests/test_io_off.py new file mode 100644 index 00000000..16cf766d --- /dev/null +++ b/tests/test_io_off.py @@ -0,0 +1,325 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest +from tempfile import NamedTemporaryFile + +import torch +from common_testing import TestCaseMixin +from pytorch3d.io import IO +from pytorch3d.renderer import TexturesAtlas, TexturesVertex +from pytorch3d.utils import ico_sphere + + +CUBE_FACES = [ + [0, 1, 2], + [7, 4, 0], + [4, 5, 1], + [5, 6, 2], + [3, 2, 6], + [6, 5, 4], + [0, 2, 3], + [7, 0, 3], + [4, 1, 0], + [5, 2, 1], + [3, 6, 7], + [6, 4, 7], +] + + +class TestMeshOffIO(TestCaseMixin, unittest.TestCase): + def test_load_face_colors(self): + # Example from wikipedia + off_file_lines = [ + "OFF", + "# cube.off", + "# A cube", + " ", + "8 6 12", + " 1.0 0.0 1.4142", + " 0.0 1.0 1.4142", + "-1.0 0.0 1.4142", + " 0.0 -1.0 1.4142", + " 1.0 0.0 0.0", + " 0.0 1.0 0.0", + "-1.0 0.0 0.0", + " 0.0 -1.0 0.0", + "4 0 1 2 3 255 0 0 #red", + "4 7 4 0 3 0 255 0 #green", + "4 4 5 1 0 0 0 255 #blue", + "4 5 6 2 1 0 255 0 ", + "4 3 2 6 7 0 0 255", + "4 6 5 4 7 255 0 0", + ] + off_file = "\n".join(off_file_lines) + io = IO() + with NamedTemporaryFile(mode="w", suffix=".off") as f: + f.write(off_file) + f.flush() + mesh = io.load_mesh(f.name) + + self.assertEqual(mesh.verts_padded().shape, (1, 8, 3)) + verts_str = " ".join(off_file_lines[5:13]) + verts_data = torch.tensor([float(i) for i in verts_str.split()]) + self.assertClose(mesh.verts_padded().flatten(), verts_data) + self.assertClose(mesh.faces_padded(), torch.tensor(CUBE_FACES)[None]) + + faces_colors_full = mesh.textures.atlas_padded() + self.assertEqual(faces_colors_full.shape, (1, 12, 1, 1, 3)) + faces_colors = faces_colors_full[0, :, 0, 0] + max_color = faces_colors.max() + self.assertEqual(max_color, 1) + + # Every face has one color 1, the rest 0. + total_color = faces_colors.sum(dim=1) + self.assertEqual(total_color.max(), max_color) + self.assertEqual(total_color.min(), max_color) + + def test_load_vertex_colors(self): + # Example with no faces and with integer vertex colors + off_file_lines = [ + "8 1 12", + " 1.0 0.0 1.4142 0 1 0", + " 0.0 1.0 1.4142 0 1 0", + "-1.0 0.0 1.4142 0 1 0", + " 0.0 -1.0 1.4142 0 1 0", + " 1.0 0.0 0.0 0 1 0", + " 0.0 1.0 0.0 0 1 0", + "-1.0 0.0 0.0 0 1 0", + " 0.0 -1.0 0.0 0 1 0", + "3 0 1 2", + ] + off_file = "\n".join(off_file_lines) + io = IO() + with NamedTemporaryFile(mode="w", suffix=".off") as f: + f.write(off_file) + f.flush() + mesh = io.load_mesh(f.name) + + self.assertEqual(mesh.verts_padded().shape, (1, 8, 3)) + verts_lines = (line.split()[:3] for line in off_file_lines[1:9]) + verts_data = [[[float(x) for x in line] for line in verts_lines]] + self.assertClose(mesh.verts_padded(), torch.tensor(verts_data)) + self.assertClose(mesh.faces_padded(), torch.tensor([[[0, 1, 2]]])) + + self.assertIsInstance(mesh.textures, TexturesVertex) + colors = mesh.textures.verts_features_padded() + + self.assertEqual(colors.shape, (1, 8, 3)) + self.assertClose(colors[0, :, [0, 2]], torch.zeros(8, 2)) + self.assertClose(colors[0, :, 1], torch.full((8,), 1.0 / 255)) + + def test_load_lumpy(self): + # Example off file whose faces have different numbers of vertices. + off_file_lines = [ + "8 3 12", + " 1.0 0.0 1.4142", + " 0.0 1.0 1.4142", + "-1.0 0.0 1.4142", + " 0.0 -1.0 1.4142", + " 1.0 0.0 0.0", + " 0.0 1.0 0.0", + "-1.0 0.0 0.0", + " 0.0 -1.0 0.0", + "3 0 1 2 255 0 0 #red", + "4 7 4 0 3 0 255 0 #green", + "4 4 5 1 0 0 0 255 #blue", + ] + off_file = "\n".join(off_file_lines) + io = IO() + with NamedTemporaryFile(mode="w", suffix=".off") as f: + f.write(off_file) + f.flush() + mesh = io.load_mesh(f.name) + + self.assertEqual(mesh.verts_padded().shape, (1, 8, 3)) + verts_str = " ".join(off_file_lines[1:9]) + verts_data = torch.tensor([float(i) for i in verts_str.split()]) + self.assertClose(mesh.verts_padded().flatten(), verts_data) + + self.assertEqual(mesh.faces_padded().shape, (1, 5, 3)) + faces_expected = [[0, 1, 2], [7, 4, 0], [7, 0, 3], [4, 5, 1], [4, 1, 0]] + self.assertClose(mesh.faces_padded()[0], torch.tensor(faces_expected)) + + def test_save_load_icosphere(self): + # Test that saving a mesh as an off file and loading it results in the + # same data on the correct device, for all permitted types of textures. + # Standard test is for random colors, but also check totally white, + # because there's a different in OFF semantics between "1.0" color (=full) + # and "1" (= 1/255 color) + sphere = ico_sphere(0) + io = IO() + device = torch.device("cuda:0") + + atlas_padded = torch.rand(1, sphere.faces_list()[0].shape[0], 1, 1, 3) + atlas = TexturesAtlas(atlas_padded) + + atlas_padded_white = torch.ones(1, sphere.faces_list()[0].shape[0], 1, 1, 3) + atlas_white = TexturesAtlas(atlas_padded_white) + + verts_colors_padded = torch.rand(1, sphere.verts_list()[0].shape[0], 3) + vertex_texture = TexturesVertex(verts_colors_padded) + + verts_colors_padded_white = torch.ones(1, sphere.verts_list()[0].shape[0], 3) + vertex_texture_white = TexturesVertex(verts_colors_padded_white) + + # No colors case + with NamedTemporaryFile(mode="w", suffix=".off") as f: + io.save_mesh(sphere, f.name) + f.flush() + mesh1 = io.load_mesh(f.name, device=device) + self.assertEqual(mesh1.device, device) + mesh1 = mesh1.cpu() + self.assertClose(mesh1.verts_padded(), sphere.verts_padded()) + self.assertClose(mesh1.faces_padded(), sphere.faces_padded()) + self.assertIsNone(mesh1.textures) + + # Atlas case + sphere.textures = atlas + with NamedTemporaryFile(mode="w", suffix=".off") as f: + io.save_mesh(sphere, f.name) + f.flush() + mesh2 = io.load_mesh(f.name, device=device) + + self.assertEqual(mesh2.device, device) + mesh2 = mesh2.cpu() + self.assertClose(mesh2.verts_padded(), sphere.verts_padded()) + self.assertClose(mesh2.faces_padded(), sphere.faces_padded()) + self.assertClose(mesh2.textures.atlas_padded(), atlas_padded, atol=1e-4) + + # White atlas case + sphere.textures = atlas_white + with NamedTemporaryFile(mode="w", suffix=".off") as f: + io.save_mesh(sphere, f.name) + f.flush() + mesh3 = io.load_mesh(f.name) + + self.assertClose(mesh3.textures.atlas_padded(), atlas_padded_white, atol=1e-4) + + # TexturesVertex case + sphere.textures = vertex_texture + with NamedTemporaryFile(mode="w", suffix=".off") as f: + io.save_mesh(sphere, f.name) + f.flush() + mesh4 = io.load_mesh(f.name, device=device) + + self.assertEqual(mesh4.device, device) + mesh4 = mesh4.cpu() + self.assertClose(mesh4.verts_padded(), sphere.verts_padded()) + self.assertClose(mesh4.faces_padded(), sphere.faces_padded()) + self.assertClose( + mesh4.textures.verts_features_padded(), verts_colors_padded, atol=1e-4 + ) + + # white TexturesVertex case + sphere.textures = vertex_texture_white + with NamedTemporaryFile(mode="w", suffix=".off") as f: + io.save_mesh(sphere, f.name) + f.flush() + mesh5 = io.load_mesh(f.name) + + self.assertClose( + mesh5.textures.verts_features_padded(), verts_colors_padded_white, atol=1e-4 + ) + + def test_bad(self): + # Test errors from various invalid OFF files. + io = IO() + + def load(lines): + off_file = "\n".join(lines) + with NamedTemporaryFile(mode="w", suffix=".off") as f: + f.write(off_file) + f.flush() + io.load_mesh(f.name) + + # First a good example + lines = [ + "4 2 12", + " 1.0 0.0 1.4142", + " 0.0 1.0 1.4142", + " 1.0 0.0 0.4142", + " 0.0 1.0 0.4142", + "3 0 1 2 ", + "3 1 3 0 ", + ] + + # This example passes. + load(lines) + + # OFF can occur on the first line separately + load(["OFF"] + lines) + + # OFF line can be merged in to the first line + lines2 = lines.copy() + lines2[0] = "OFF " + lines[0] + load(lines2) + + with self.assertRaisesRegex(ValueError, "Not enough face data."): + load(lines[:-1]) + + lines2 = lines.copy() + lines2[0] = "4 1 12" + with self.assertRaisesRegex(ValueError, "Extra data at end of file:"): + load(lines2) + + lines2 = lines.copy() + lines2[-1] = "2 1 3" + with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."): + load(lines2) + + lines2 = lines.copy() + lines2[-1] = "4 1 3 0" + with self.assertRaisesRegex( + ValueError, "A line of face data did not have the specified length." + ): + load(lines2) + + lines2 = lines.copy() + lines2[0] = "6 2 0" + with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"): + load(lines2) + + lines2[0] = "5 1 0" + with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"): + load(lines2) + + lines2[0] = "16 2 0" + with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"): + load(lines2) + + lines2[0] = "3 3 0" + # This is a bit of a special case because the last vertex could be a face + with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."): + load(lines2) + + lines2[4] = "7.3 4.2 8.3" + with self.assertRaisesRegex( + ValueError, "A line of face data did not have the specified length." + ): + load(lines2) + + # Now try bad number of colors + + lines2 = lines.copy() + lines2[2] = "7.3 4.2 8.3 932" + with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 2"): + load(lines2) + + lines2[1] = "7.3 4.2 8.3 932" + lines2[3] = "7.3 4.2 8.3 932" + lines2[4] = "7.3 4.2 8.3 932" + with self.assertRaisesRegex(ValueError, "Bad vertex data."): + load(lines2) + + lines2 = lines.copy() + lines2[5] = "3 0 1 2 0.9" + lines2[6] = "3 0 3 0 0.9" + with self.assertRaisesRegex(ValueError, "Unexpected number of colors."): + load(lines2) + + lines2 = lines.copy() + for i in range(1, 7): + lines2[i] = lines2[i] + " 4 4 4 4" + msg = "Faces colors ignored because vertex colors provided too." + with self.assertWarnsRegex(UserWarning, msg): + load(lines2) diff --git a/tests/test_ply_io.py b/tests/test_io_ply.py similarity index 100% rename from tests/test_ply_io.py rename to tests/test_io_ply.py