mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Loading/saving meshes to OFF files.
Summary: Implements the ascii OFF file format. This was discussed in https://github.com/facebookresearch/pytorch3d/issues/216 Reviewed By: theschnitz Differential Revision: D25788834 fbshipit-source-id: c141d1f4ba3bad24e3c1f280a20aee782bfd74d6
This commit is contained in:
parent
4bfe7158b1
commit
0345f860d4
488
pytorch3d/io/off_io.py
Normal file
488
pytorch3d/io/off_io.py
Normal file
@ -0,0 +1,488 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
"""
|
||||
This module implements utility functions for loading and saving
|
||||
meshes as .off files.
|
||||
"""
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io.utils import _check_faces_indices, _open_file
|
||||
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
from .pluggable_formats import MeshFormatInterpreter, endswith
|
||||
|
||||
|
||||
def _is_line_empty(line: Union[str, bytes]) -> bool:
|
||||
"""
|
||||
Returns whether line is not relevant in an OFF file.
|
||||
"""
|
||||
line = line.strip()
|
||||
return len(line) == 0 or line[:1] == b"#"
|
||||
|
||||
|
||||
def _count_next_line_periods(file) -> int:
|
||||
"""
|
||||
Returns the number of . characters before any # on the next
|
||||
meaningful line.
|
||||
"""
|
||||
old_offset = file.tell()
|
||||
line = file.readline()
|
||||
while _is_line_empty(line):
|
||||
line = file.readline()
|
||||
if len(line) == 0:
|
||||
raise ValueError("Premature end of file")
|
||||
|
||||
contents = line.split(b"#")[0]
|
||||
count = contents.count(b".")
|
||||
file.seek(old_offset)
|
||||
return count
|
||||
|
||||
|
||||
def _read_faces_lump(
|
||||
file, n_faces: int, n_colors: Optional[int]
|
||||
) -> Optional[Tuple[np.ndarray, int, Optional[np.ndarray]]]:
|
||||
"""
|
||||
Parse n_faces faces and faces_colors from the file,
|
||||
if they all have the same number of vertices.
|
||||
This is used in two ways.
|
||||
1) To try to read all faces.
|
||||
2) To read faces one-by-one if that failed.
|
||||
|
||||
Args:
|
||||
file: file-like object being read.
|
||||
n_faces: The known number of faces yet to read.
|
||||
n_colors: The number of colors if known already.
|
||||
|
||||
Returns:
|
||||
- 2D numpy array of faces
|
||||
- number of colors found
|
||||
- 2D numpy array of face colors if found.
|
||||
of None if there are faces with different numbers of vertices.
|
||||
"""
|
||||
if n_faces == 0:
|
||||
return np.array([[]]), 0, None
|
||||
old_offset = file.tell()
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", message=".* Empty input file.*", category=UserWarning
|
||||
)
|
||||
data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_faces)
|
||||
except ValueError as e:
|
||||
if n_faces > 1 and "Wrong number of columns" in e.args[0]:
|
||||
file.seek(old_offset)
|
||||
return None
|
||||
raise ValueError("Not enough face data.")
|
||||
|
||||
if len(data) != n_faces:
|
||||
raise ValueError("Not enough face data.")
|
||||
face_size = int(data[0, 0])
|
||||
if (data[:, 0] != face_size).any():
|
||||
msg = "A line of face data did not have the specified length."
|
||||
raise ValueError(msg)
|
||||
if face_size < 3:
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
|
||||
n_colors_found = data.shape[1] - 1 - face_size
|
||||
if n_colors is not None and n_colors_found != n_colors:
|
||||
raise ValueError("Number of colors differs between faces.")
|
||||
n_colors = n_colors_found
|
||||
if n_colors not in [0, 3, 4]:
|
||||
raise ValueError("Unexpected number of colors.")
|
||||
|
||||
face_raw_data = data[:, 1 : 1 + face_size].astype("int64")
|
||||
if face_size == 3:
|
||||
face_data = face_raw_data
|
||||
else:
|
||||
face_arrays = [
|
||||
face_raw_data[:, [0, i + 1, i + 2]] for i in range(face_size - 2)
|
||||
]
|
||||
face_data = np.vstack(face_arrays)
|
||||
|
||||
if n_colors == 0:
|
||||
return face_data, 0, None
|
||||
colors = data[:, 1 + face_size :]
|
||||
if face_size == 3:
|
||||
return face_data, n_colors, colors
|
||||
return face_data, n_colors, np.tile(colors, (face_size - 2, 1))
|
||||
|
||||
|
||||
def _read_faces(
|
||||
file, n_faces: int
|
||||
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
Returns faces and face colors from the file.
|
||||
|
||||
Args:
|
||||
file: file-like object being read.
|
||||
n_faces: The known number of faces.
|
||||
|
||||
Returns:
|
||||
2D numpy arrays of faces and face colors, or None for each if
|
||||
they are not present.
|
||||
"""
|
||||
if n_faces == 0:
|
||||
return None, None
|
||||
|
||||
color_is_int = 0 == _count_next_line_periods(file)
|
||||
color_scale = 1 / 255.0 if color_is_int else 1
|
||||
|
||||
faces_ncolors_colors = _read_faces_lump(file, n_faces=n_faces, n_colors=None)
|
||||
if faces_ncolors_colors is not None:
|
||||
faces, _, colors = faces_ncolors_colors
|
||||
if colors is None:
|
||||
return faces, None
|
||||
return faces, colors * color_scale
|
||||
|
||||
faces_list, colors_list = [], []
|
||||
n_colors = None
|
||||
for _ in range(n_faces):
|
||||
faces_ncolors_colors = _read_faces_lump(file, n_faces=1, n_colors=n_colors)
|
||||
faces_found, n_colors, colors_found = cast(
|
||||
Tuple[np.ndarray, int, Optional[np.ndarray]], faces_ncolors_colors
|
||||
)
|
||||
faces_list.append(faces_found)
|
||||
colors_list.append(colors_found)
|
||||
faces = np.vstack(faces_list)
|
||||
if n_colors == 0:
|
||||
colors = None
|
||||
else:
|
||||
colors = np.vstack(colors_list) * color_scale
|
||||
return faces, colors
|
||||
|
||||
|
||||
def _read_verts(file, n_verts: int) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""
|
||||
Returns verts and vertex colors from the file.
|
||||
|
||||
Args:
|
||||
file: file-like object being read.
|
||||
n_verts: The known number of faces.
|
||||
|
||||
Returns:
|
||||
2D numpy arrays of verts and (if present)
|
||||
vertex colors.
|
||||
"""
|
||||
|
||||
color_is_int = 3 == _count_next_line_periods(file)
|
||||
color_scale = 1 / 255.0 if color_is_int else 1
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", message=".* Empty input file.*", category=UserWarning
|
||||
)
|
||||
data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_verts)
|
||||
if data.shape[0] != n_verts:
|
||||
raise ValueError("Not enough vertex data.")
|
||||
if data.shape[1] not in [3, 6, 7]:
|
||||
raise ValueError("Bad vertex data.")
|
||||
|
||||
if data.shape[1] == 3:
|
||||
return data, None
|
||||
return data[:, :3], data[:, 3:] * color_scale # []
|
||||
|
||||
|
||||
def _load_off_stream(file) -> dict:
|
||||
"""
|
||||
Load the data from a stream of an .off file.
|
||||
|
||||
Example .off file format:
|
||||
|
||||
off
|
||||
8 6 1927 { number of vertices, faces, and (not used) edges }
|
||||
# comment { comments with # sign }
|
||||
0 0 0 { start of vertex list }
|
||||
0 0 1
|
||||
0 1 1
|
||||
0 1 0
|
||||
1 0 0
|
||||
1 0 1
|
||||
1 1 1
|
||||
1 1 0
|
||||
4 0 1 2 3 { start of face list }
|
||||
4 7 6 5 4
|
||||
4 0 4 5 1
|
||||
4 1 5 6 2
|
||||
4 2 6 7 3
|
||||
4 3 7 4 0
|
||||
|
||||
Args:
|
||||
file: A binary file-like object (with methods read, readline,
|
||||
tell and seek).
|
||||
|
||||
Returns dictionary possibly containing:
|
||||
verts: (always present) FloatTensor of shape (V, 3).
|
||||
verts_colors: FloatTensor of shape (V, C) where C is 3 or 4.
|
||||
faces: LongTensor of vertex indices, split into triangles, shape (F, 3).
|
||||
faces_colors: FloatTensor of shape (F, C), where C is 3 or 4.
|
||||
"""
|
||||
header = file.readline()
|
||||
if header.lower() in (b"off\n", b"off\r\n", "off\n"):
|
||||
header = file.readline()
|
||||
|
||||
while _is_line_empty(header):
|
||||
header = file.readline()
|
||||
|
||||
items = header.split(b" ")
|
||||
if len(items) and items[0].lower() in ("off", b"off"):
|
||||
items = items[1:]
|
||||
if len(items) < 3:
|
||||
raise ValueError("Invalid counts line: %s" % header)
|
||||
|
||||
try:
|
||||
n_verts = int(items[0])
|
||||
except ValueError:
|
||||
raise ValueError("Invalid counts line: %s" % header)
|
||||
try:
|
||||
n_faces = int(items[1])
|
||||
except ValueError:
|
||||
raise ValueError("Invalid counts line: %s" % header)
|
||||
|
||||
if (len(items) > 3 and not items[3].startswith("#")) or n_verts < 0 or n_faces < 0:
|
||||
raise ValueError("Invalid counts line: %s" % header)
|
||||
|
||||
verts, verts_colors = _read_verts(file, n_verts)
|
||||
faces, faces_colors = _read_faces(file, n_faces)
|
||||
|
||||
end = file.read().strip()
|
||||
if len(end) != 0:
|
||||
raise ValueError("Extra data at end of file: " + str(end[:20]))
|
||||
|
||||
out = {"verts": verts}
|
||||
if verts_colors is not None:
|
||||
out["verts_colors"] = verts_colors
|
||||
if faces is not None:
|
||||
out["faces"] = faces
|
||||
if faces_colors is not None:
|
||||
out["faces_colors"] = faces_colors
|
||||
return out
|
||||
|
||||
|
||||
def _write_off_data(
|
||||
file,
|
||||
verts: torch.Tensor,
|
||||
verts_colors: Optional[torch.Tensor] = None,
|
||||
faces: Optional[torch.LongTensor] = None,
|
||||
faces_colors: Optional[torch.Tensor] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Internal implementation for saving 3D data to a .off file.
|
||||
|
||||
Args:
|
||||
file: Binary file object to which the 3D data should be written.
|
||||
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
||||
verts_colors: FloatTensor of shape (V, C) giving vertex colors where C is 3 or 4.
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
faces_colors: FloatTensor of shape (V, C) giving face colors where C is 3 or 4.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
"""
|
||||
nfaces = 0 if faces is None else faces.shape[0]
|
||||
file.write(f"off\n{verts.shape[0]} {nfaces} 0\n".encode("ascii"))
|
||||
|
||||
if verts_colors is not None:
|
||||
verts = torch.cat((verts, verts_colors), dim=1)
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
np.savetxt(file, verts.cpu().detach().numpy(), float_str)
|
||||
|
||||
if faces is not None:
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
if faces_colors is not None:
|
||||
face_data = torch.cat(
|
||||
[
|
||||
cast(torch.Tensor, faces).cpu().to(torch.float64),
|
||||
faces_colors.detach().cpu().to(torch.float64),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
format = "3 %d %d %d" + " %f" * faces_colors.shape[1]
|
||||
np.savetxt(file, face_data.numpy(), format)
|
||||
elif faces is not None:
|
||||
np.savetxt(file, faces.cpu().detach().numpy(), "3 %d %d %d")
|
||||
|
||||
|
||||
def _save_off(
|
||||
file,
|
||||
*,
|
||||
verts: torch.Tensor,
|
||||
verts_colors: Optional[torch.Tensor] = None,
|
||||
faces: Optional[torch.LongTensor] = None,
|
||||
faces_colors: Optional[torch.Tensor] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
path_manager: PathManager,
|
||||
) -> None:
|
||||
"""
|
||||
Save a mesh to an ascii .off file.
|
||||
|
||||
Args:
|
||||
file: File (or path) to which the mesh should be written.
|
||||
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
||||
verts_colors: FloatTensor of shape (V, C) giving vertex colors where C is 3 or 4.
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
faces_colors: FloatTensor of shape (V, C) giving face colors where C is 3 or 4.
|
||||
decimal_places: Number of decimal places for saving.
|
||||
"""
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if verts_colors is not None and 0 == len(verts_colors):
|
||||
verts_colors = None
|
||||
if faces_colors is not None and 0 == len(faces_colors):
|
||||
faces_colors = None
|
||||
if faces is not None and 0 == len(faces):
|
||||
faces = None
|
||||
|
||||
if verts_colors is not None:
|
||||
if not (verts_colors.dim() == 2 and verts_colors.size(1) in [3, 4]):
|
||||
message = "verts_colors should have shape (num_faces, C)."
|
||||
raise ValueError(message)
|
||||
if verts_colors.shape[0] != verts.shape[0]:
|
||||
message = "verts_colors should have the same length as verts."
|
||||
raise ValueError(message)
|
||||
|
||||
if faces is not None and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||
message = "Argument 'faces' if present should have shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
if faces_colors is not None and faces is None:
|
||||
message = "Cannot have face colors without faces"
|
||||
raise ValueError(message)
|
||||
|
||||
if faces_colors is not None:
|
||||
if not (faces_colors.dim() == 2 and faces_colors.size(1) in [3, 4]):
|
||||
message = "faces_colors should have shape (num_faces, C)."
|
||||
raise ValueError(message)
|
||||
if faces_colors.shape[0] != cast(torch.LongTensor, faces).shape[0]:
|
||||
message = "faces_colors should have the same length as faces."
|
||||
raise ValueError(message)
|
||||
|
||||
with _open_file(file, path_manager, "wb") as f:
|
||||
_write_off_data(f, verts, verts_colors, faces, faces_colors, decimal_places)
|
||||
|
||||
|
||||
class MeshOffFormat(MeshFormatInterpreter):
|
||||
"""
|
||||
Loads and saves meshes in the ascii OFF format. This is a simple
|
||||
format which can only deal with the following texture types:
|
||||
|
||||
- TexturesVertex, i.e. one color for each vertex
|
||||
- TexturesAtlas with R=1, i.e. one color for each face.
|
||||
|
||||
There are some possible features of OFF files which we do not support
|
||||
and which appear to be rare:
|
||||
|
||||
- Four dimensional data.
|
||||
- Binary data.
|
||||
- Vertex Normals.
|
||||
- Texture coordinates.
|
||||
- "COFF" header.
|
||||
|
||||
Example .off file format:
|
||||
|
||||
off
|
||||
8 6 1927 { number of vertices, faces, and (not used) edges }
|
||||
# comment { comments with # sign }
|
||||
0 0 0 { start of vertex list }
|
||||
0 0 1
|
||||
0 1 1
|
||||
0 1 0
|
||||
1 0 0
|
||||
1 0 1
|
||||
1 1 1
|
||||
1 1 0
|
||||
4 0 1 2 3 { start of face list }
|
||||
4 7 6 5 4
|
||||
4 0 4 5 1
|
||||
4 1 5 6 2
|
||||
4 2 6 7 3
|
||||
4 3 7 4 0
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.known_suffixes = (".off",)
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
include_textures: bool,
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
**kwargs,
|
||||
) -> Optional[Meshes]:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
with _open_file(path, path_manager, "rb") as f:
|
||||
data = _load_off_stream(f)
|
||||
verts = torch.from_numpy(data["verts"]).to(device)
|
||||
if "faces" in data:
|
||||
faces = torch.from_numpy(data["faces"]).to(dtype=torch.int64, device=device)
|
||||
else:
|
||||
faces = torch.zeros((0, 3), dtype=torch.int64, device=device)
|
||||
|
||||
textures = None
|
||||
if "verts_colors" in data:
|
||||
if "faces_colors" in data:
|
||||
msg = "Faces colors ignored because vertex colors provided too."
|
||||
warnings.warn(msg)
|
||||
verts_colors = torch.from_numpy(data["verts_colors"]).to(device)
|
||||
textures = TexturesVertex([verts_colors])
|
||||
elif "faces_colors" in data:
|
||||
faces_colors = torch.from_numpy(data["faces_colors"]).to(device)
|
||||
textures = TexturesAtlas([faces_colors[:, None, None, :]])
|
||||
|
||||
mesh = Meshes(
|
||||
verts=[verts.to(device)], faces=[faces.to(device)], textures=textures
|
||||
)
|
||||
return mesh
|
||||
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return False
|
||||
|
||||
verts = data.verts_list()[0]
|
||||
faces = data.faces_list()[0]
|
||||
if isinstance(data.textures, TexturesVertex):
|
||||
[verts_colors] = data.textures.verts_features_list()
|
||||
else:
|
||||
verts_colors = None
|
||||
|
||||
faces_colors = None
|
||||
if isinstance(data.textures, TexturesAtlas):
|
||||
[atlas] = data.textures.atlas_list()
|
||||
F, R, _, D = atlas.shape
|
||||
if R == 1:
|
||||
faces_colors = atlas[:, 0, 0, :]
|
||||
|
||||
_save_off(
|
||||
file=path,
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
verts_colors=verts_colors,
|
||||
faces_colors=faces_colors,
|
||||
decimal_places=decimal_places,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
return True
|
@ -11,6 +11,7 @@ from iopath.common.file_io import PathManager
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
from .obj_io import MeshObjFormat
|
||||
from .off_io import MeshOffFormat
|
||||
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
|
||||
from .ply_io import MeshPlyFormat, PointcloudPlyFormat
|
||||
|
||||
@ -73,6 +74,7 @@ class IO:
|
||||
|
||||
def register_default_formats(self) -> None:
|
||||
self.register_meshes_format(MeshObjFormat())
|
||||
self.register_meshes_format(MeshOffFormat())
|
||||
self.register_meshes_format(MeshPlyFormat())
|
||||
self.register_pointcloud_format(PointcloudPlyFormat())
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
"""
|
||||
This module implements utility functions for loading and saving
|
||||
meshes and point clouds from PLY files.
|
||||
meshes and point clouds as PLY files.
|
||||
"""
|
||||
import itertools
|
||||
import struct
|
||||
|
@ -3,8 +3,8 @@
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_obj_io import TestMeshObjIO
|
||||
from test_ply_io import TestMeshPlyIO
|
||||
from test_io_obj import TestMeshObjIO
|
||||
from test_io_ply import TestMeshPlyIO
|
||||
|
||||
|
||||
def bm_save_load() -> None:
|
||||
|
325
tests/test_io_off.py
Normal file
325
tests/test_io_off.py
Normal file
@ -0,0 +1,325 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
|
||||
from pytorch3d.utils import ico_sphere
|
||||
|
||||
|
||||
CUBE_FACES = [
|
||||
[0, 1, 2],
|
||||
[7, 4, 0],
|
||||
[4, 5, 1],
|
||||
[5, 6, 2],
|
||||
[3, 2, 6],
|
||||
[6, 5, 4],
|
||||
[0, 2, 3],
|
||||
[7, 0, 3],
|
||||
[4, 1, 0],
|
||||
[5, 2, 1],
|
||||
[3, 6, 7],
|
||||
[6, 4, 7],
|
||||
]
|
||||
|
||||
|
||||
class TestMeshOffIO(TestCaseMixin, unittest.TestCase):
|
||||
def test_load_face_colors(self):
|
||||
# Example from wikipedia
|
||||
off_file_lines = [
|
||||
"OFF",
|
||||
"# cube.off",
|
||||
"# A cube",
|
||||
" ",
|
||||
"8 6 12",
|
||||
" 1.0 0.0 1.4142",
|
||||
" 0.0 1.0 1.4142",
|
||||
"-1.0 0.0 1.4142",
|
||||
" 0.0 -1.0 1.4142",
|
||||
" 1.0 0.0 0.0",
|
||||
" 0.0 1.0 0.0",
|
||||
"-1.0 0.0 0.0",
|
||||
" 0.0 -1.0 0.0",
|
||||
"4 0 1 2 3 255 0 0 #red",
|
||||
"4 7 4 0 3 0 255 0 #green",
|
||||
"4 4 5 1 0 0 0 255 #blue",
|
||||
"4 5 6 2 1 0 255 0 ",
|
||||
"4 3 2 6 7 0 0 255",
|
||||
"4 6 5 4 7 255 0 0",
|
||||
]
|
||||
off_file = "\n".join(off_file_lines)
|
||||
io = IO()
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
f.write(off_file)
|
||||
f.flush()
|
||||
mesh = io.load_mesh(f.name)
|
||||
|
||||
self.assertEqual(mesh.verts_padded().shape, (1, 8, 3))
|
||||
verts_str = " ".join(off_file_lines[5:13])
|
||||
verts_data = torch.tensor([float(i) for i in verts_str.split()])
|
||||
self.assertClose(mesh.verts_padded().flatten(), verts_data)
|
||||
self.assertClose(mesh.faces_padded(), torch.tensor(CUBE_FACES)[None])
|
||||
|
||||
faces_colors_full = mesh.textures.atlas_padded()
|
||||
self.assertEqual(faces_colors_full.shape, (1, 12, 1, 1, 3))
|
||||
faces_colors = faces_colors_full[0, :, 0, 0]
|
||||
max_color = faces_colors.max()
|
||||
self.assertEqual(max_color, 1)
|
||||
|
||||
# Every face has one color 1, the rest 0.
|
||||
total_color = faces_colors.sum(dim=1)
|
||||
self.assertEqual(total_color.max(), max_color)
|
||||
self.assertEqual(total_color.min(), max_color)
|
||||
|
||||
def test_load_vertex_colors(self):
|
||||
# Example with no faces and with integer vertex colors
|
||||
off_file_lines = [
|
||||
"8 1 12",
|
||||
" 1.0 0.0 1.4142 0 1 0",
|
||||
" 0.0 1.0 1.4142 0 1 0",
|
||||
"-1.0 0.0 1.4142 0 1 0",
|
||||
" 0.0 -1.0 1.4142 0 1 0",
|
||||
" 1.0 0.0 0.0 0 1 0",
|
||||
" 0.0 1.0 0.0 0 1 0",
|
||||
"-1.0 0.0 0.0 0 1 0",
|
||||
" 0.0 -1.0 0.0 0 1 0",
|
||||
"3 0 1 2",
|
||||
]
|
||||
off_file = "\n".join(off_file_lines)
|
||||
io = IO()
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
f.write(off_file)
|
||||
f.flush()
|
||||
mesh = io.load_mesh(f.name)
|
||||
|
||||
self.assertEqual(mesh.verts_padded().shape, (1, 8, 3))
|
||||
verts_lines = (line.split()[:3] for line in off_file_lines[1:9])
|
||||
verts_data = [[[float(x) for x in line] for line in verts_lines]]
|
||||
self.assertClose(mesh.verts_padded(), torch.tensor(verts_data))
|
||||
self.assertClose(mesh.faces_padded(), torch.tensor([[[0, 1, 2]]]))
|
||||
|
||||
self.assertIsInstance(mesh.textures, TexturesVertex)
|
||||
colors = mesh.textures.verts_features_padded()
|
||||
|
||||
self.assertEqual(colors.shape, (1, 8, 3))
|
||||
self.assertClose(colors[0, :, [0, 2]], torch.zeros(8, 2))
|
||||
self.assertClose(colors[0, :, 1], torch.full((8,), 1.0 / 255))
|
||||
|
||||
def test_load_lumpy(self):
|
||||
# Example off file whose faces have different numbers of vertices.
|
||||
off_file_lines = [
|
||||
"8 3 12",
|
||||
" 1.0 0.0 1.4142",
|
||||
" 0.0 1.0 1.4142",
|
||||
"-1.0 0.0 1.4142",
|
||||
" 0.0 -1.0 1.4142",
|
||||
" 1.0 0.0 0.0",
|
||||
" 0.0 1.0 0.0",
|
||||
"-1.0 0.0 0.0",
|
||||
" 0.0 -1.0 0.0",
|
||||
"3 0 1 2 255 0 0 #red",
|
||||
"4 7 4 0 3 0 255 0 #green",
|
||||
"4 4 5 1 0 0 0 255 #blue",
|
||||
]
|
||||
off_file = "\n".join(off_file_lines)
|
||||
io = IO()
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
f.write(off_file)
|
||||
f.flush()
|
||||
mesh = io.load_mesh(f.name)
|
||||
|
||||
self.assertEqual(mesh.verts_padded().shape, (1, 8, 3))
|
||||
verts_str = " ".join(off_file_lines[1:9])
|
||||
verts_data = torch.tensor([float(i) for i in verts_str.split()])
|
||||
self.assertClose(mesh.verts_padded().flatten(), verts_data)
|
||||
|
||||
self.assertEqual(mesh.faces_padded().shape, (1, 5, 3))
|
||||
faces_expected = [[0, 1, 2], [7, 4, 0], [7, 0, 3], [4, 5, 1], [4, 1, 0]]
|
||||
self.assertClose(mesh.faces_padded()[0], torch.tensor(faces_expected))
|
||||
|
||||
def test_save_load_icosphere(self):
|
||||
# Test that saving a mesh as an off file and loading it results in the
|
||||
# same data on the correct device, for all permitted types of textures.
|
||||
# Standard test is for random colors, but also check totally white,
|
||||
# because there's a different in OFF semantics between "1.0" color (=full)
|
||||
# and "1" (= 1/255 color)
|
||||
sphere = ico_sphere(0)
|
||||
io = IO()
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
atlas_padded = torch.rand(1, sphere.faces_list()[0].shape[0], 1, 1, 3)
|
||||
atlas = TexturesAtlas(atlas_padded)
|
||||
|
||||
atlas_padded_white = torch.ones(1, sphere.faces_list()[0].shape[0], 1, 1, 3)
|
||||
atlas_white = TexturesAtlas(atlas_padded_white)
|
||||
|
||||
verts_colors_padded = torch.rand(1, sphere.verts_list()[0].shape[0], 3)
|
||||
vertex_texture = TexturesVertex(verts_colors_padded)
|
||||
|
||||
verts_colors_padded_white = torch.ones(1, sphere.verts_list()[0].shape[0], 3)
|
||||
vertex_texture_white = TexturesVertex(verts_colors_padded_white)
|
||||
|
||||
# No colors case
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
io.save_mesh(sphere, f.name)
|
||||
f.flush()
|
||||
mesh1 = io.load_mesh(f.name, device=device)
|
||||
self.assertEqual(mesh1.device, device)
|
||||
mesh1 = mesh1.cpu()
|
||||
self.assertClose(mesh1.verts_padded(), sphere.verts_padded())
|
||||
self.assertClose(mesh1.faces_padded(), sphere.faces_padded())
|
||||
self.assertIsNone(mesh1.textures)
|
||||
|
||||
# Atlas case
|
||||
sphere.textures = atlas
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
io.save_mesh(sphere, f.name)
|
||||
f.flush()
|
||||
mesh2 = io.load_mesh(f.name, device=device)
|
||||
|
||||
self.assertEqual(mesh2.device, device)
|
||||
mesh2 = mesh2.cpu()
|
||||
self.assertClose(mesh2.verts_padded(), sphere.verts_padded())
|
||||
self.assertClose(mesh2.faces_padded(), sphere.faces_padded())
|
||||
self.assertClose(mesh2.textures.atlas_padded(), atlas_padded, atol=1e-4)
|
||||
|
||||
# White atlas case
|
||||
sphere.textures = atlas_white
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
io.save_mesh(sphere, f.name)
|
||||
f.flush()
|
||||
mesh3 = io.load_mesh(f.name)
|
||||
|
||||
self.assertClose(mesh3.textures.atlas_padded(), atlas_padded_white, atol=1e-4)
|
||||
|
||||
# TexturesVertex case
|
||||
sphere.textures = vertex_texture
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
io.save_mesh(sphere, f.name)
|
||||
f.flush()
|
||||
mesh4 = io.load_mesh(f.name, device=device)
|
||||
|
||||
self.assertEqual(mesh4.device, device)
|
||||
mesh4 = mesh4.cpu()
|
||||
self.assertClose(mesh4.verts_padded(), sphere.verts_padded())
|
||||
self.assertClose(mesh4.faces_padded(), sphere.faces_padded())
|
||||
self.assertClose(
|
||||
mesh4.textures.verts_features_padded(), verts_colors_padded, atol=1e-4
|
||||
)
|
||||
|
||||
# white TexturesVertex case
|
||||
sphere.textures = vertex_texture_white
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
io.save_mesh(sphere, f.name)
|
||||
f.flush()
|
||||
mesh5 = io.load_mesh(f.name)
|
||||
|
||||
self.assertClose(
|
||||
mesh5.textures.verts_features_padded(), verts_colors_padded_white, atol=1e-4
|
||||
)
|
||||
|
||||
def test_bad(self):
|
||||
# Test errors from various invalid OFF files.
|
||||
io = IO()
|
||||
|
||||
def load(lines):
|
||||
off_file = "\n".join(lines)
|
||||
with NamedTemporaryFile(mode="w", suffix=".off") as f:
|
||||
f.write(off_file)
|
||||
f.flush()
|
||||
io.load_mesh(f.name)
|
||||
|
||||
# First a good example
|
||||
lines = [
|
||||
"4 2 12",
|
||||
" 1.0 0.0 1.4142",
|
||||
" 0.0 1.0 1.4142",
|
||||
" 1.0 0.0 0.4142",
|
||||
" 0.0 1.0 0.4142",
|
||||
"3 0 1 2 ",
|
||||
"3 1 3 0 ",
|
||||
]
|
||||
|
||||
# This example passes.
|
||||
load(lines)
|
||||
|
||||
# OFF can occur on the first line separately
|
||||
load(["OFF"] + lines)
|
||||
|
||||
# OFF line can be merged in to the first line
|
||||
lines2 = lines.copy()
|
||||
lines2[0] = "OFF " + lines[0]
|
||||
load(lines2)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Not enough face data."):
|
||||
load(lines[:-1])
|
||||
|
||||
lines2 = lines.copy()
|
||||
lines2[0] = "4 1 12"
|
||||
with self.assertRaisesRegex(ValueError, "Extra data at end of file:"):
|
||||
load(lines2)
|
||||
|
||||
lines2 = lines.copy()
|
||||
lines2[-1] = "2 1 3"
|
||||
with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."):
|
||||
load(lines2)
|
||||
|
||||
lines2 = lines.copy()
|
||||
lines2[-1] = "4 1 3 0"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "A line of face data did not have the specified length."
|
||||
):
|
||||
load(lines2)
|
||||
|
||||
lines2 = lines.copy()
|
||||
lines2[0] = "6 2 0"
|
||||
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"):
|
||||
load(lines2)
|
||||
|
||||
lines2[0] = "5 1 0"
|
||||
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"):
|
||||
load(lines2)
|
||||
|
||||
lines2[0] = "16 2 0"
|
||||
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 5"):
|
||||
load(lines2)
|
||||
|
||||
lines2[0] = "3 3 0"
|
||||
# This is a bit of a special case because the last vertex could be a face
|
||||
with self.assertRaisesRegex(ValueError, "Faces must have at least 3 vertices."):
|
||||
load(lines2)
|
||||
|
||||
lines2[4] = "7.3 4.2 8.3"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "A line of face data did not have the specified length."
|
||||
):
|
||||
load(lines2)
|
||||
|
||||
# Now try bad number of colors
|
||||
|
||||
lines2 = lines.copy()
|
||||
lines2[2] = "7.3 4.2 8.3 932"
|
||||
with self.assertRaisesRegex(ValueError, "Wrong number of columns at line 2"):
|
||||
load(lines2)
|
||||
|
||||
lines2[1] = "7.3 4.2 8.3 932"
|
||||
lines2[3] = "7.3 4.2 8.3 932"
|
||||
lines2[4] = "7.3 4.2 8.3 932"
|
||||
with self.assertRaisesRegex(ValueError, "Bad vertex data."):
|
||||
load(lines2)
|
||||
|
||||
lines2 = lines.copy()
|
||||
lines2[5] = "3 0 1 2 0.9"
|
||||
lines2[6] = "3 0 3 0 0.9"
|
||||
with self.assertRaisesRegex(ValueError, "Unexpected number of colors."):
|
||||
load(lines2)
|
||||
|
||||
lines2 = lines.copy()
|
||||
for i in range(1, 7):
|
||||
lines2[i] = lines2[i] + " 4 4 4 4"
|
||||
msg = "Faces colors ignored because vertex colors provided too."
|
||||
with self.assertWarnsRegex(UserWarning, msg):
|
||||
load(lines2)
|
Loading…
x
Reference in New Issue
Block a user