Loading/saving meshes to OFF files.

Summary: Implements the ascii OFF file format. This was discussed in https://github.com/facebookresearch/pytorch3d/issues/216

Reviewed By: theschnitz

Differential Revision: D25788834

fbshipit-source-id: c141d1f4ba3bad24e3c1f280a20aee782bfd74d6
This commit is contained in:
Jeremy Reizenstein 2021-02-12 07:04:30 -08:00 committed by Facebook GitHub Bot
parent 4bfe7158b1
commit 0345f860d4
7 changed files with 818 additions and 3 deletions

488
pytorch3d/io/off_io.py Normal file
View File

@ -0,0 +1,488 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
This module implements utility functions for loading and saving
meshes as .off files.
"""
import warnings
from pathlib import Path
from typing import Optional, Tuple, Union, cast
import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
from pytorch3d.structures import Meshes
from .pluggable_formats import MeshFormatInterpreter, endswith
def _is_line_empty(line: Union[str, bytes]) -> bool:
"""
Returns whether line is not relevant in an OFF file.
"""
line = line.strip()
return len(line) == 0 or line[:1] == b"#"
def _count_next_line_periods(file) -> int:
"""
Returns the number of . characters before any # on the next
meaningful line.
"""
old_offset = file.tell()
line = file.readline()
while _is_line_empty(line):
line = file.readline()
if len(line) == 0:
raise ValueError("Premature end of file")
contents = line.split(b"#")[0]
count = contents.count(b".")
file.seek(old_offset)
return count
def _read_faces_lump(
file, n_faces: int, n_colors: Optional[int]
) -> Optional[Tuple[np.ndarray, int, Optional[np.ndarray]]]:
"""
Parse n_faces faces and faces_colors from the file,
if they all have the same number of vertices.
This is used in two ways.
1) To try to read all faces.
2) To read faces one-by-one if that failed.
Args:
file: file-like object being read.
n_faces: The known number of faces yet to read.
n_colors: The number of colors if known already.
Returns:
- 2D numpy array of faces
- number of colors found
- 2D numpy array of face colors if found.
of None if there are faces with different numbers of vertices.
"""
if n_faces == 0:
return np.array([[]]), 0, None
old_offset = file.tell()
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=".* Empty input file.*", category=UserWarning
)
data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_faces)
except ValueError as e:
if n_faces > 1 and "Wrong number of columns" in e.args[0]:
file.seek(old_offset)
return None
raise ValueError("Not enough face data.")
if len(data) != n_faces:
raise ValueError("Not enough face data.")
face_size = int(data[0, 0])
if (data[:, 0] != face_size).any():
msg = "A line of face data did not have the specified length."
raise ValueError(msg)
if face_size < 3:
raise ValueError("Faces must have at least 3 vertices.")
n_colors_found = data.shape[1] - 1 - face_size
if n_colors is not None and n_colors_found != n_colors:
raise ValueError("Number of colors differs between faces.")
n_colors = n_colors_found
if n_colors not in [0, 3, 4]:
raise ValueError("Unexpected number of colors.")
face_raw_data = data[:, 1 : 1 + face_size].astype("int64")
if face_size == 3:
face_data = face_raw_data
else:
face_arrays = [
face_raw_data[:, [0, i + 1, i + 2]] for i in range(face_size - 2)
]
face_data = np.vstack(face_arrays)
if n_colors == 0:
return face_data, 0, None
colors = data[:, 1 + face_size :]
if face_size == 3:
return face_data, n_colors, colors
return face_data, n_colors, np.tile(colors, (face_size - 2, 1))
def _read_faces(
file, n_faces: int
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""
Returns faces and face colors from the file.
Args:
file: file-like object being read.
n_faces: The known number of faces.
Returns:
2D numpy arrays of faces and face colors, or None for each if
they are not present.
"""
if n_faces == 0:
return None, None
color_is_int = 0 == _count_next_line_periods(file)
color_scale = 1 / 255.0 if color_is_int else 1
faces_ncolors_colors = _read_faces_lump(file, n_faces=n_faces, n_colors=None)
if faces_ncolors_colors is not None:
faces, _, colors = faces_ncolors_colors
if colors is None:
return faces, None
return faces, colors * color_scale
faces_list, colors_list = [], []
n_colors = None
for _ in range(n_faces):
faces_ncolors_colors = _read_faces_lump(file, n_faces=1, n_colors=n_colors)
faces_found, n_colors, colors_found = cast(
Tuple[np.ndarray, int, Optional[np.ndarray]], faces_ncolors_colors
)
faces_list.append(faces_found)
colors_list.append(colors_found)
faces = np.vstack(faces_list)
if n_colors == 0:
colors = None
else:
colors = np.vstack(colors_list) * color_scale
return faces, colors
def _read_verts(file, n_verts: int) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Returns verts and vertex colors from the file.
Args:
file: file-like object being read.
n_verts: The known number of faces.
Returns:
2D numpy arrays of verts and (if present)
vertex colors.
"""
color_is_int = 3 == _count_next_line_periods(file)
color_scale = 1 / 255.0 if color_is_int else 1
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=".* Empty input file.*", category=UserWarning
)
data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_verts)
if data.shape[0] != n_verts:
raise ValueError("Not enough vertex data.")
if data.shape[1] not in [3, 6, 7]:
raise ValueError("Bad vertex data.")
if data.shape[1] == 3:
return data, None
return data[:, :3], data[:, 3:] * color_scale # []
def _load_off_stream(file) -> dict:
"""
Load the data from a stream of an .off file.
Example .off file format:
off
8 6 1927 { number of vertices, faces, and (not used) edges }
# comment { comments with # sign }
0 0 0 { start of vertex list }
0 0 1
0 1 1
0 1 0
1 0 0
1 0 1
1 1 1
1 1 0
4 0 1 2 3 { start of face list }
4 7 6 5 4
4 0 4 5 1
4 1 5 6 2
4 2 6 7 3
4 3 7 4 0
Args:
file: A binary file-like object (with methods read, readline,
tell and seek).
Returns dictionary possibly containing:
verts: (always present) FloatTensor of shape (V, 3).
verts_colors: FloatTensor of shape (V, C) where C is 3 or 4.
faces: LongTensor of vertex indices, split into triangles, shape (F, 3).
faces_colors: FloatTensor of shape (F, C), where C is 3 or 4.
"""
header = file.readline()
if header.lower() in (b"off\n", b"off\r\n", "off\n"):
header = file.readline()
while _is_line_empty(header):
header = file.readline()
items = header.split(b" ")
if len(items) and items[0].lower() in ("off", b"off"):
items = items[1:]
if len(items) < 3:
raise ValueError("Invalid counts line: %s" % header)
try:
n_verts = int(items[0])
except ValueError:
raise ValueError("Invalid counts line: %s" % header)
try:
n_faces = int(items[1])
except ValueError:
raise ValueError("Invalid counts line: %s" % header)
if (len(items) > 3 and not items[3].startswith("#")) or n_verts < 0 or n_faces < 0:
raise ValueError("Invalid counts line: %s" % header)
verts, verts_colors = _read_verts(file, n_verts)
faces, faces_colors = _read_faces(file, n_faces)
end = file.read().strip()
if len(end) != 0:
raise ValueError("Extra data at end of file: " + str(end[:20]))
out = {"verts": verts}
if verts_colors is not None:
out["verts_colors"] = verts_colors
if faces is not None:
out["faces"] = faces
if faces_colors is not None:
out["faces_colors"] = faces_colors
return out
def _write_off_data(
file,
verts: torch.Tensor,
verts_colors: Optional[torch.Tensor] = None,
faces: Optional[torch.LongTensor] = None,
faces_colors: Optional[torch.Tensor] = None,
decimal_places: Optional[int] = None,
) -> None:
"""
Internal implementation for saving 3D data to a .off file.
Args:
file: Binary file object to which the 3D data should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
verts_colors: FloatTensor of shape (V, C) giving vertex colors where C is 3 or 4.
faces: LongTensor of shape (F, 3) giving faces.
faces_colors: FloatTensor of shape (V, C) giving face colors where C is 3 or 4.
decimal_places: Number of decimal places for saving.
"""
nfaces = 0 if faces is None else faces.shape[0]
file.write(f"off\n{verts.shape[0]} {nfaces} 0\n".encode("ascii"))
if verts_colors is not None:
verts = torch.cat((verts, verts_colors), dim=1)
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
np.savetxt(file, verts.cpu().detach().numpy(), float_str)
if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0])
if faces_colors is not None:
face_data = torch.cat(
[
cast(torch.Tensor, faces).cpu().to(torch.float64),
faces_colors.detach().cpu().to(torch.float64),
],
dim=1,
)
format = "3 %d %d %d" + " %f" * faces_colors.shape[1]
np.savetxt(file, face_data.numpy(), format)
elif faces is not None:
np.savetxt(file, faces.cpu().detach().numpy(), "3 %d %d %d")
def _save_off(
file,
*,
verts: torch.Tensor,
verts_colors: Optional[torch.Tensor] = None,
faces: Optional[torch.LongTensor] = None,
faces_colors: Optional[torch.Tensor] = None,
decimal_places: Optional[int] = None,
path_manager: PathManager,
) -> None:
"""
Save a mesh to an ascii .off file.
Args:
file: File (or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
verts_colors: FloatTensor of shape (V, C) giving vertex colors where C is 3 or 4.
faces: LongTensor of shape (F, 3) giving faces.
faces_colors: FloatTensor of shape (V, C) giving face colors where C is 3 or 4.
decimal_places: Number of decimal places for saving.
"""
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if verts_colors is not None and 0 == len(verts_colors):
verts_colors = None
if faces_colors is not None and 0 == len(faces_colors):
faces_colors = None
if faces is not None and 0 == len(faces):
faces = None
if verts_colors is not None:
if not (verts_colors.dim() == 2 and verts_colors.size(1) in [3, 4]):
message = "verts_colors should have shape (num_faces, C)."
raise ValueError(message)
if verts_colors.shape[0] != verts.shape[0]:
message = "verts_colors should have the same length as verts."
raise ValueError(message)
if faces is not None and not (faces.dim() == 2 and faces.size(1) == 3):
message = "Argument 'faces' if present should have shape (num_faces, 3)."
raise ValueError(message)
if faces_colors is not None and faces is None:
message = "Cannot have face colors without faces"
raise ValueError(message)
if faces_colors is not None:
if not (faces_colors.dim() == 2 and faces_colors.size(1) in [3, 4]):
message = "faces_colors should have shape (num_faces, C)."
raise ValueError(message)
if faces_colors.shape[0] != cast(torch.LongTensor, faces).shape[0]:
message = "faces_colors should have the same length as faces."
raise ValueError(message)
with _open_file(file, path_manager, "wb") as f:
_write_off_data(f, verts, verts_colors, faces, faces_colors, decimal_places)
class MeshOffFormat(MeshFormatInterpreter):
"""
Loads and saves meshes in the ascii OFF format. This is a simple
format which can only deal with the following texture types:
- TexturesVertex, i.e. one color for each vertex
- TexturesAtlas with R=1, i.e. one color for each face.
There are some possible features of OFF files which we do not support
and which appear to be rare:
- Four dimensional data.
- Binary data.
- Vertex Normals.
- Texture coordinates.
- "COFF" header.
Example .off file format:
off
8 6 1927 { number of vertices, faces, and (not used) edges }
# comment { comments with # sign }
0 0 0 { start of vertex list }
0 0 1
0 1 1
0 1 0
1 0 0
1 0 1
1 1 1
1 1 0
4 0 1 2 3 { start of face list }
4 7 6 5 4
4 0 4 5 1
4 1 5 6 2
4 2 6 7 3
4 3 7 4 0
"""
def __init__(self):
self.known_suffixes = (".off",)
def read(
self,
path: Union[str, Path],
include_textures: bool,
device,
path_manager: PathManager,
**kwargs,
) -> Optional[Meshes]:
if not endswith(path, self.known_suffixes):
return None
with _open_file(path, path_manager, "rb") as f:
data = _load_off_stream(f)
verts = torch.from_numpy(data["verts"]).to(device)
if "faces" in data:
faces = torch.from_numpy(data["faces"]).to(dtype=torch.int64, device=device)
else:
faces = torch.zeros((0, 3), dtype=torch.int64, device=device)
textures = None
if "verts_colors" in data:
if "faces_colors" in data:
msg = "Faces colors ignored because vertex colors provided too."
warnings.warn(msg)
verts_colors = torch.from_numpy(data["verts_colors"]).to(device)
textures = TexturesVertex([verts_colors])
elif "faces_colors" in data:
faces_colors = torch.from_numpy(data["faces_colors"]).to(device)
textures = TexturesAtlas([faces_colors[:, None, None, :]])
mesh = Meshes(
verts=[verts.to(device)], faces=[faces.to(device)], textures=textures
)
return mesh
def save(
self,
data: Meshes,
path: Union[str, Path],
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
**kwargs,
) -> bool:
if not endswith(path, self.known_suffixes):
return False
verts = data.verts_list()[0]
faces = data.faces_list()[0]
if isinstance(data.textures, TexturesVertex):
[verts_colors] = data.textures.verts_features_list()
else:
verts_colors = None
faces_colors = None
if isinstance(data.textures, TexturesAtlas):
[atlas] = data.textures.atlas_list()
F, R, _, D = atlas.shape
if R == 1:
faces_colors = atlas[:, 0, 0, :]
_save_off(
file=path,
verts=verts,
faces=faces,
verts_colors=verts_colors,
faces_colors=faces_colors,
decimal_places=decimal_places,
path_manager=path_manager,
)
return True

View File

@ -11,6 +11,7 @@ from iopath.common.file_io import PathManager
from pytorch3d.structures import Meshes, Pointclouds
from .obj_io import MeshObjFormat
from .off_io import MeshOffFormat
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
from .ply_io import MeshPlyFormat, PointcloudPlyFormat
@ -73,6 +74,7 @@ class IO:
def register_default_formats(self) -> None:
self.register_meshes_format(MeshObjFormat())
self.register_meshes_format(MeshOffFormat())
self.register_meshes_format(MeshPlyFormat())
self.register_pointcloud_format(PointcloudPlyFormat())

View File

@ -5,7 +5,7 @@
"""
This module implements utility functions for loading and saving
meshes and point clouds from PLY files.
meshes and point clouds as PLY files.
"""
import itertools
import struct

View File

@ -3,8 +3,8 @@
from itertools import product
from fvcore.common.benchmark import benchmark
from test_obj_io import TestMeshObjIO
from test_ply_io import TestMeshPlyIO
from test_io_obj import TestMeshObjIO
from test_io_ply import TestMeshPlyIO
def bm_save_load() -> None:

325
tests/test_io_off.py Normal file
View File

@ -0,0 +1,325 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from tempfile import NamedTemporaryFile
import torch
from common_testing import TestCaseMixin
from pytorch3d.io import IO
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
from pytorch3d.utils import ico_sphere
CUBE_FACES = [
[0, 1, 2],
[7, 4, 0],
[4, 5, 1],
[5, 6, 2],
[3, 2, 6],
[6, 5, 4],
[0, 2, 3],
[7, 0, 3],
[4, 1, 0],
[5, 2, 1],
[3, 6, 7],
[6, 4, 7],
]
class TestMeshOffIO(TestCaseMixin, unittest.TestCase):
def test_load_face_colors(self):
# Example from wikipedia
off_file_lines = [
"OFF",
"# cube.off",
"# A cube",
" ",
"8 6 12",
" 1.0 0.0 1.4142",
" 0.0 1.0 1.4142",
"-1.0 0.0 1.4142",
" 0.0 -1.0 1.4142",
" 1.0 0.0 0.0",
" 0.0 1.0 0.0",
"-1.0 0.0 0.0",
" 0.0 -1.0 0.0",
"4 0 1 2 3 255 0 0 #red",
"4 7 4 0 3 0 255 0 #green",
"4 4 5 1 0 0 0 255 #blue",
"4 5 6 2 1 0 255 0 ",
"4 3 2 6 7 0 0 255",
"4 6 5 4 7 255 0 0",
]
off_file = "\n".join(off_file_lines)
io = IO()
with NamedTemporaryFile(mode="w", suffix=".off") as f:
f.write(off_file)
f.flush()
mesh = io.load_mesh(f.name)
self.assertEqual(mesh.verts_padded().shape, (1, 8, 3))
verts_str = " ".join(off_file_lines[5:13])
verts_data = torch.tensor([float(i) for i in verts_str.split()])
self.assertClose(mesh.verts_padded().flatten(), verts_data)
self.assertClose(mesh.faces_padded(), torch.tensor(CUBE_FACES)[None])
faces_colors_full = mesh.textures.atlas_padded()
self.assertEqual(faces_colors_full.shape, (1, 12, 1, 1, 3))
faces_colors = faces_colors_full[0, :, 0, 0]
max_color = faces_colors.max()
self.assertEqual(max_color, 1)
# Every face has one color 1, the rest 0.
total_color = faces_colors.sum(dim=1)
self.assertEqual(total_color.max(), max_color)
self.assertEqual(total_color.min(), max_color)
def test_load_vertex_colors(self):
# Example with no faces and with integer vertex colors
off_file_lines = [
"8 1 12",
" 1.0 0.0 1.4142 0 1 0",
" 0.0 1.0 1.4142 0 1 0",
"-1.0 0.0 1.4142 0 1 0",
" 0.0 -1.0 1.4142 0 1 0",
" 1.0 0.0 0.0 0 1 0",
" 0.0 1.0 0.0 0 1 0",
"-1.0 0.0 0.0 0 1 0",
" 0.0 -1.0 0.0 0 1 0",
"3 0 1 2",
]
off_file = "\n".join(off_file_lines)
io = IO()
with NamedTemporaryFile(mode="w", suffix=".off") as f:
f.write(off_file)
f.flush()
mesh = io.load_mesh(f.name)
self.assertEqual(mesh.verts_padded().shape, (1, 8, 3))
verts_lines = (line.split()[:3] for line in off_file_lines[1:9])
verts_data = [[[float(x) for x in line] for line in verts_lines]]
self.assertClose(mesh.verts_padded(), torch.tensor(verts_data))
self.assertClose(mesh.faces_padded(), torch.tensor([[[0, 1, 2]]]))
self.assertIsInstance(mesh.textures, TexturesVertex)
colors = mesh.textures.verts_features_padded()
self.assertEqual(colors.shape, (1, 8, 3))
self.assertClose(colors[0, :, [0, 2]], torch.zeros(8, 2))
self.assertClose(colors[0, :, 1], torch.full((8,), 1.0 / 255))
def test_load_lumpy(self):
# Example off file whose faces have different numbers of vertices.
off_file_lines = [
"8 3 12",
" 1.0 0.0 1.4142",
" 0.0 1.0 1.4142",
"-1.0 0.0 1.4142",
" 0.0 -1.0 1.4142",
" 1.0 0.0 0.0",
" 0.0 1.0 0.0",
"-1.0 0.0 0.0",
" 0.0 -1.0 0.0",
"3 0 1 2 255 0 0 #red",
"4 7 4 0 3 0 255 0 #green",
"4 4 5 1 0 0 0 255 #blue",
]
off_file = "\n".join(off_file_lines)
io = IO()
with NamedTemporaryFile(mode="w", suffix=".off") as f:
f.write(off_file)
f.flush()
mesh = io.load_mesh(f.name)
self.assertEqual(mesh.verts_padded().shape, (1, 8, 3))
verts_str = " ".join(off_file_lines[1:9])
verts_data = torch.tensor([float(i) for i in verts_str.split()])
self.assertClose(mesh.verts_padded().flatten(), verts_data)
self.assertEqual(mesh.faces_padded().shape, (1, 5, 3))
faces_expected = [[0, 1, 2], [7, 4, 0], [7, 0, 3], [4, 5, 1], [4, 1, 0]]
self.assertClose(mesh.faces_padded()[0], torch.tensor(faces_expected))
def test_save_load_icosphere(self):
# Test that saving a mesh as an off file and loading it results in the
# same data on the correct device, for all permitted types of textures.
# Standard test is for random colors, but also check totally white,
# because there's a different in OFF semantics between "1.0" color (=full)
# and "1" (= 1/255 color)
sphere = ico_sphere(0)
io = IO()
device = torch.device("cuda:0")
atlas_padded = torch.rand(1, sphere.faces_list()[0].shape[0], 1, 1, 3)
atlas = TexturesAtlas(atlas_padded)
atlas_padded_white = torch.ones(1, sphere.faces_list()[0].shape[0], 1, 1, 3)
atlas_white = TexturesAtlas(atlas_padded_white)
verts_colors_padded = torch.rand(1, sphere.verts_list()[0].shape[0], 3)
vertex_texture = TexturesVertex(verts_colors_padded)
verts_colors_padded_white = torch.ones(1, sphere.verts_list()[0].shape[0], 3)
vertex_texture_white = TexturesVertex(verts_colors_padded_white)
# No colors case
with NamedTemporaryFile(mode="w", suffix=".off") as f:
io.save_mesh(sphere, f.name)
f.flush()
mesh1 = io.load_mesh(f.name, device=device)
self.assertEqual(mesh1.device, device)
mesh1 = mesh1.cpu()
self.assertClose(mesh1.verts_padded(), sphere.verts_padded())
self.assertClose(mesh1.faces_padded(), sphere.faces_padded())
self.assertIsNone(mesh1.textures)
# Atlas case
sphere.textures = atlas
with NamedTemporaryFile(mode="w", suffix=".off") as f:
io.save_mesh(sphere, f.name)
f.flush()
mesh2 = io.load_mesh(f.name, device=device)
self.assertEqual(mesh2.device, device)
mesh2 = mesh2.cpu()
self.assertClose(mesh2.verts_padded(), sphere.verts_padded())
self.assertClose(mesh2.faces_padded(), sphere.faces_padded())
self.assertClose(mesh2.textures.atlas_padded(), atlas_padded, atol=1e-4)
# White atlas case
sphere.textures = atlas_white
with NamedTemporaryFile(mode="w", suffix=".off") as f:
io.save_mesh(sphere, f.name)
f.flush()
mesh3 = io.load_mesh(f.name)
self.assertClose(mesh3.textures.atlas_padded(), atlas_padded_white, atol=1e-4)
# TexturesVertex case
sphere.textures = vertex_texture
with NamedTemporaryFile(mode="w", suffix=".off") as f:
io.save_mesh(sphere, f.name)
f.flush()
mesh4 = io.load_mesh(f.name, device=device)
self.assertEqual(mesh4.device, device)
mesh4 = mesh4.cpu()
self.assertClose(mesh4.verts_padded(), sphere.verts_padded())
self.assertClose(mesh4.faces_padded(), sphere.faces_padded())
self.assertClose(
mesh4.textures.verts_features_padded(), verts_colors_padded, atol=1e-4
)
# white TexturesVertex case
sphere.textures = vertex_texture_white
with NamedTemporaryFile(mode="w", suffix=".off") as f:
io.save_mesh(sphere, f.name)
f.flush()
mesh5 = io.load_mesh(f.name)
self.assertClose(
mesh5.textures.verts_features_padded(), verts_colors_padded_white, atol=1e-4
)
def test_bad(self):
# Test errors from various invalid OFF files.
io = IO()
def load(lines):
off_file = "\n".join(lines)
with NamedTemporaryFile(mode="w", suffix=".off") as f:
f.write(off_file)
f.flush()
io.load_mesh(f.name)
# First a good example
lines = [
"4 2 12",
" 1.0 0.0 1.4142",
" 0.0 1.0 1.4142",
" 1.0 0.0 0.4142",
" 0.0 1.0 0.4142",
"3 0 1 2 ",
"3 1 3 0 ",
]
# This example passes.
load(lines)
# OFF can occur on the first line separately
load(["OFF"] + lines)
# OFF line can be merged in to the first line
lines2 = lines.copy()
lines2[0] = "OFF " + lines[0]
load(lines2)
with self.assertRaisesRegex(ValueError, "Not enough face data."):
load(lines[:-1])
lines2 = lines.copy()
lines2[0] = "4 1 12"
with self.assertRaisesRegex(ValueError, "Extra data at end of file:"):
load(lines2)
lines2 = lines.copy()
lines2[-1] = "2 1 3"
with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."):
load(lines2)
lines2 = lines.copy()
lines2[-1] = "4 1 3 0"
with self.assertRaisesRegex(
ValueError, "A line of face data did not have the specified length."
):
load(lines2)
lines2 = lines.copy()
lines2[0] = "6 2 0"
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"):
load(lines2)
lines2[0] = "5 1 0"
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"):
load(lines2)
lines2[0] = "16 2 0"
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"):
load(lines2)
lines2[0] = "3 3 0"
# This is a bit of a special case because the last vertex could be a face
with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."):
load(lines2)
lines2[4] = "7.3 4.2 8.3"
with self.assertRaisesRegex(
ValueError, "A line of face data did not have the specified length."
):
load(lines2)
# Now try bad number of colors
lines2 = lines.copy()
lines2[2] = "7.3 4.2 8.3 932"
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 2"):
load(lines2)
lines2[1] = "7.3 4.2 8.3 932"
lines2[3] = "7.3 4.2 8.3 932"
lines2[4] = "7.3 4.2 8.3 932"
with self.assertRaisesRegex(ValueError, "Bad vertex data."):
load(lines2)
lines2 = lines.copy()
lines2[5] = "3 0 1 2 0.9"
lines2[6] = "3 0 3 0 0.9"
with self.assertRaisesRegex(ValueError, "Unexpected number of colors."):
load(lines2)
lines2 = lines.copy()
for i in range(1, 7):
lines2[i] = lines2[i] + " 4 4 4 4"
msg = "Faces colors ignored because vertex colors provided too."
with self.assertWarnsRegex(UserWarning, msg):
load(lines2)