mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
PLY pointcloud loading
Summary: Allow PLY files to not contain faces. Allow loading pointclouds with color, at least encoded according to the way of some cloudcompare examples. TODO: Allow vertex normals to be read, and allow vertex colors to be written. Make the return type of load_ply something more user friendly, like a dict. Noticed in https://github.com/facebookresearch/pytorch3d/issues/209 Reviewed By: nikhilaravi Differential Revision: D22573314 fbshipit-source-id: 72ba1f7c6417f5dfc83f2ebf359eff017057635c
This commit is contained in:
parent
3b9fbfc08c
commit
95707fba1c
@ -12,7 +12,7 @@ from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
from .obj_io import MeshObjFormat
|
||||
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
|
||||
from .ply_io import MeshPlyFormat
|
||||
from .ply_io import MeshPlyFormat, PointcloudPlyFormat
|
||||
|
||||
|
||||
"""
|
||||
@ -74,6 +74,7 @@ class IO:
|
||||
def register_default_formats(self) -> None:
|
||||
self.register_meshes_format(MeshObjFormat())
|
||||
self.register_meshes_format(MeshPlyFormat())
|
||||
self.register_pointcloud_format(PointcloudPlyFormat())
|
||||
|
||||
def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
|
||||
"""
|
||||
|
@ -3,23 +3,30 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
"""This module implements utility functions for loading and saving meshes."""
|
||||
"""
|
||||
This module implements utility functions for loading and saving
|
||||
meshes and point clouds from PLY files.
|
||||
"""
|
||||
import itertools
|
||||
import struct
|
||||
import sys
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from io import BytesIO
|
||||
from io import BytesIO, TextIOBase
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
from .pluggable_formats import MeshFormatInterpreter, endswith
|
||||
from .pluggable_formats import (
|
||||
MeshFormatInterpreter,
|
||||
PointcloudFormatInterpreter,
|
||||
endswith,
|
||||
)
|
||||
|
||||
|
||||
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
|
||||
@ -127,7 +134,7 @@ class _PlyHeader:
|
||||
self.elements: (List[_PlyElementType]) element description
|
||||
self.ascii: (bool) Whether in ascii format
|
||||
self.big_endian: (bool) (if not ascii) whether big endian
|
||||
self.obj_info: (dict) arbitrary extra data
|
||||
self.obj_info: (List[str]) arbitrary extra data
|
||||
|
||||
Args:
|
||||
f: file-like object.
|
||||
@ -136,7 +143,7 @@ class _PlyHeader:
|
||||
raise ValueError("Invalid file header.")
|
||||
seen_format = False
|
||||
self.elements = []
|
||||
self.obj_info = {}
|
||||
self.obj_info = []
|
||||
while True:
|
||||
line = f.readline()
|
||||
if isinstance(line, bytes):
|
||||
@ -172,11 +179,8 @@ class _PlyHeader:
|
||||
if line.startswith("element"):
|
||||
self._parse_element(line)
|
||||
continue
|
||||
if line.startswith("obj_info"):
|
||||
items = line.split(" ")
|
||||
if len(items) != 3:
|
||||
raise ValueError("Invalid line: %s" % line)
|
||||
self.obj_info[items[1]] = items[2]
|
||||
if line.startswith("obj_info "):
|
||||
self.obj_info.append(line[9:])
|
||||
continue
|
||||
if line.startswith("property"):
|
||||
self._parse_property(line)
|
||||
@ -736,6 +740,10 @@ def _load_ply_raw_stream(f) -> Tuple[_PlyHeader, dict]:
|
||||
for element in header.elements:
|
||||
elements[element.name] = _read_ply_element_ascii(f, element)
|
||||
else:
|
||||
if isinstance(f, TextIOBase):
|
||||
raise ValueError(
|
||||
"Cannot safely read a binary ply file using a Text stream."
|
||||
)
|
||||
big = header.big_endian
|
||||
for element in header.elements:
|
||||
elements[element.name] = _read_ply_element_binary(f, element, big)
|
||||
@ -769,7 +777,187 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
|
||||
return header, elements
|
||||
|
||||
|
||||
def load_ply(f, path_manager: Optional[PathManager] = None):
|
||||
def _get_verts_column_indices(
|
||||
vertex_head: _PlyElementType,
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
"""
|
||||
Get the columns of verts and verts_colors in the vertex
|
||||
element of a parsed ply file.
|
||||
|
||||
Args:
|
||||
vertex_head: as returned from load_ply_raw.
|
||||
|
||||
Returns:
|
||||
point_idxs: List[int] of 3 point columns.
|
||||
color_idxs: List[int] of 3 color columns if they are present,
|
||||
otherwise None.
|
||||
"""
|
||||
point_idxs: List[Optional[int]] = [None, None, None]
|
||||
color_idxs: List[Optional[int]] = [None, None, None]
|
||||
for i, prop in enumerate(vertex_head.properties):
|
||||
if prop.list_size_type is not None:
|
||||
raise ValueError("Invalid vertices in file: did not expect list.")
|
||||
for j, letter in enumerate(["x", "y", "z"]):
|
||||
if prop.name == letter:
|
||||
point_idxs[j] = i
|
||||
for j, name in enumerate(["red", "green", "blue"]):
|
||||
if prop.name == name:
|
||||
color_idxs[j] = i
|
||||
if None in point_idxs:
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
if None in color_idxs:
|
||||
return point_idxs, None
|
||||
return point_idxs, color_idxs
|
||||
|
||||
|
||||
def _get_verts(
|
||||
header: _PlyHeader, elements: dict
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Get the vertex locations and colors from a parsed ply file.
|
||||
|
||||
Args:
|
||||
header, elements: as returned from load_ply_raw.
|
||||
|
||||
Returns:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
||||
"""
|
||||
|
||||
vertex = elements.get("vertex", None)
|
||||
if vertex is None:
|
||||
raise ValueError("The ply file has no vertex element.")
|
||||
if not isinstance(vertex, list):
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
||||
point_idxs, color_idxs = _get_verts_column_indices(vertex_head)
|
||||
|
||||
# Case of no vertices
|
||||
if vertex_head.count == 0:
|
||||
verts = torch.zeros((0, 3), dtype=torch.float32)
|
||||
if color_idxs is None:
|
||||
return verts, None
|
||||
return verts, torch.zeros((0, 3), dtype=torch.float32)
|
||||
|
||||
# Simple case where the only data is the vertices themselves
|
||||
if (
|
||||
len(vertex) == 1
|
||||
and isinstance(vertex[0], np.ndarray)
|
||||
and vertex[0].ndim == 2
|
||||
and vertex[0].shape[1] == 3
|
||||
):
|
||||
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None
|
||||
|
||||
vertex_colors = None
|
||||
|
||||
if len(vertex) == 1:
|
||||
# This is the case where the whole vertex element has one type,
|
||||
# so it was read as a single array and we can index straight into it.
|
||||
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
|
||||
if color_idxs is not None:
|
||||
vertex_colors = torch.tensor(vertex[0][:, color_idxs], dtype=torch.float32)
|
||||
else:
|
||||
# The vertex element is heterogeneous. It was read as several arrays,
|
||||
# part by part, where a part is a set of properties with the same type.
|
||||
# For each property (=column in the file), we store in
|
||||
# prop_to_partnum_col its partnum (i.e. the index of what part it is
|
||||
# in) and its column number (its index within its part).
|
||||
prop_to_partnum_col = [
|
||||
(partnum, col)
|
||||
for partnum, array in enumerate(vertex)
|
||||
for col in range(array.shape[1])
|
||||
]
|
||||
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
|
||||
for axis in range(3):
|
||||
partnum, col = prop_to_partnum_col[point_idxs[axis]]
|
||||
verts.numpy()[:, axis] = vertex[partnum][:, col]
|
||||
# Note that in the previous line, we made the assignment
|
||||
# as numpy arrays by casting verts. If we took the (more
|
||||
# obvious) method of converting the right hand side to
|
||||
# torch, then we might have an extra data copy because
|
||||
# torch wants contiguity. The code would be like:
|
||||
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
|
||||
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
|
||||
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
|
||||
if color_idxs is not None:
|
||||
vertex_colors = torch.empty(
|
||||
size=(vertex_head.count, 3), dtype=torch.float32
|
||||
)
|
||||
for color in range(3):
|
||||
partnum, col = prop_to_partnum_col[color_idxs[color]]
|
||||
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
||||
|
||||
return verts, vertex_colors
|
||||
|
||||
|
||||
def _load_ply(
|
||||
f, *, path_manager: PathManager, return_vertex_colors: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Load the data from a .ply file.
|
||||
|
||||
Args:
|
||||
f: A binary or text file-like object (with methods read, readline,
|
||||
tell and seek), a pathlib path or a string containing a file name.
|
||||
If the ply file is in the binary ply format rather than the text
|
||||
ply format, then a text stream is not supported.
|
||||
It is easiest to use a binary stream in all cases.
|
||||
path_manager: PathManager for loading if f is a str.
|
||||
return_vertex_colors: whether to return vertex colors.
|
||||
|
||||
Returns:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
faces: None or LongTensor of vertex indices, shape (F, 3).
|
||||
vertex_colors: None or FloatTensor of shape (V, 3), only if requested
|
||||
"""
|
||||
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
||||
|
||||
verts, vertex_colors = _get_verts(header, elements)
|
||||
|
||||
face = elements.get("face", None)
|
||||
if face is not None:
|
||||
face_head = next(head for head in header.elements if head.name == "face")
|
||||
if (
|
||||
len(face_head.properties) != 1
|
||||
or face_head.properties[0].list_size_type is None
|
||||
):
|
||||
raise ValueError("Unexpected form of faces data.")
|
||||
# face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
|
||||
# but we don't need to enforce this.
|
||||
|
||||
if face is None:
|
||||
faces = None
|
||||
elif not len(face):
|
||||
# pyre is happier when this condition is not joined to the
|
||||
# previous one with `or`.
|
||||
faces = None
|
||||
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
|
||||
if face.shape[1] < 3:
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)]
|
||||
faces = torch.LongTensor(np.vstack(face_arrays))
|
||||
else:
|
||||
face_list = []
|
||||
for face_item in face:
|
||||
if face_item.ndim != 1:
|
||||
raise ValueError("Bad face data.")
|
||||
if face_item.shape[0] < 3:
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
for i in range(face_item.shape[0] - 2):
|
||||
face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
|
||||
faces = torch.tensor(face_list, dtype=torch.int64)
|
||||
|
||||
if faces is not None:
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
if return_vertex_colors:
|
||||
return verts, faces, vertex_colors
|
||||
return verts, faces, None
|
||||
|
||||
|
||||
def load_ply(
|
||||
f, *, path_manager: Optional[PathManager] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Load the data from a .ply file.
|
||||
|
||||
@ -809,72 +997,27 @@ def load_ply(f, path_manager: Optional[PathManager] = None):
|
||||
It is easiest to use a binary stream in all cases.
|
||||
path_manager: PathManager for loading if f is a str.
|
||||
|
||||
|
||||
Returns:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
faces: LongTensor of vertex indices, shape (F, 3).
|
||||
"""
|
||||
|
||||
if path_manager is None:
|
||||
path_manager = PathManager()
|
||||
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
||||
verts, faces, _ = _load_ply(f, path_manager=path_manager)
|
||||
if faces is None:
|
||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||
|
||||
vertex = elements.get("vertex", None)
|
||||
if vertex is None:
|
||||
raise ValueError("The ply file has no vertex element.")
|
||||
|
||||
face = elements.get("face", None)
|
||||
if face is None:
|
||||
raise ValueError("The ply file has no face element.")
|
||||
|
||||
if not isinstance(vertex, list) or len(vertex) > 1:
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
|
||||
if len(vertex):
|
||||
vertex0 = vertex[0]
|
||||
if len(vertex0) and (
|
||||
not isinstance(vertex0, np.ndarray)
|
||||
or vertex0.ndim != 2
|
||||
or vertex0.shape[1] != 3
|
||||
):
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
else:
|
||||
vertex0 = []
|
||||
verts = _make_tensor(vertex0, cols=3, dtype=torch.float32)
|
||||
|
||||
face_head = next(head for head in header.elements if head.name == "face")
|
||||
if len(face_head.properties) != 1 or face_head.properties[0].list_size_type is None:
|
||||
raise ValueError("Unexpected form of faces data.")
|
||||
# face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
|
||||
# but we don't need to enforce this.
|
||||
|
||||
if not len(face):
|
||||
faces = torch.zeros((0, 3), dtype=torch.int64)
|
||||
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
|
||||
if face.shape[1] < 3:
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)]
|
||||
faces = torch.LongTensor(np.vstack(face_arrays))
|
||||
else:
|
||||
face_list = []
|
||||
for face_item in face:
|
||||
if face_item.ndim != 1:
|
||||
raise ValueError("Bad face data.")
|
||||
if face_item.shape[0] < 3:
|
||||
raise ValueError("Faces must have at least 3 vertices.")
|
||||
for i in range(face_item.shape[0] - 2):
|
||||
face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
|
||||
# pyre-fixme[6]: Expected `dtype` for 3rd param but got `Type[torch.int64]`.
|
||||
faces = _make_tensor(face_list, cols=3, dtype=torch.int64)
|
||||
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
return verts, faces
|
||||
|
||||
|
||||
def _save_ply(
|
||||
f,
|
||||
*,
|
||||
verts: torch.Tensor,
|
||||
faces: torch.LongTensor,
|
||||
faces: Optional[torch.LongTensor],
|
||||
verts_normals: torch.Tensor,
|
||||
verts_colors: torch.Tensor,
|
||||
ascii: bool,
|
||||
decimal_places: Optional[int] = None,
|
||||
) -> None:
|
||||
@ -890,10 +1033,14 @@ def _save_ply(
|
||||
decimal_places: Number of decimal places for saving if ascii=True.
|
||||
"""
|
||||
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||
if faces is not None:
|
||||
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||
assert not len(verts_normals) or (
|
||||
verts_normals.dim() == 2 and verts_normals.size(1) == 3
|
||||
)
|
||||
assert not len(verts_colors) or (
|
||||
verts_colors.dim() == 2 and verts_colors.size(1) == 3
|
||||
)
|
||||
|
||||
if ascii:
|
||||
f.write(b"ply\nformat ascii 1.0\n")
|
||||
@ -909,15 +1056,20 @@ def _save_ply(
|
||||
f.write(b"property float nx\n")
|
||||
f.write(b"property float ny\n")
|
||||
f.write(b"property float nz\n")
|
||||
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
|
||||
f.write(b"property list uchar int vertex_index\n")
|
||||
if verts_colors.numel() > 0:
|
||||
f.write(b"property float red\n")
|
||||
f.write(b"property float green\n")
|
||||
f.write(b"property float blue\n")
|
||||
if len(verts) and faces is not None:
|
||||
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
|
||||
f.write(b"property list uchar int vertex_index\n")
|
||||
f.write(b"end_header\n")
|
||||
|
||||
if not (len(verts) or len(faces)):
|
||||
warnings.warn("Empty 'verts' and 'faces' arguments provided")
|
||||
if not (len(verts)):
|
||||
warnings.warn("Empty 'verts' provided")
|
||||
return
|
||||
|
||||
vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
|
||||
vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy()
|
||||
if ascii:
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
@ -932,21 +1084,22 @@ def _save_ply(
|
||||
else:
|
||||
vert_data.tofile(f)
|
||||
|
||||
faces_array = faces.detach().numpy()
|
||||
if faces is not None:
|
||||
faces_array = faces.detach().numpy()
|
||||
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
if len(faces_array):
|
||||
if ascii:
|
||||
np.savetxt(f, faces_array, "3 %d %d %d")
|
||||
else:
|
||||
# rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
|
||||
faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
|
||||
faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
|
||||
if isinstance(f, BytesIO):
|
||||
f.write(faces_uints.tobytes())
|
||||
if len(faces_array):
|
||||
if ascii:
|
||||
np.savetxt(f, faces_array, "3 %d %d %d")
|
||||
else:
|
||||
faces_uints.tofile(f)
|
||||
# rows are 13 bytes: a one-byte 3 followed by three four-byte face indices.
|
||||
faces_uints = np.full((len(faces_array), 13), 3, dtype=np.uint8)
|
||||
faces_uints[:, 1:] = faces_array.astype(np.uint32).view(np.uint8)
|
||||
if isinstance(f, BytesIO):
|
||||
f.write(faces_uints.tobytes())
|
||||
else:
|
||||
faces_uints.tofile(f)
|
||||
|
||||
|
||||
def save_ply(
|
||||
@ -977,13 +1130,16 @@ def save_ply(
|
||||
if verts_normals is None
|
||||
else verts_normals
|
||||
)
|
||||
faces = torch.LongTensor([]) if faces is None else faces
|
||||
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||
if (
|
||||
faces is not None
|
||||
and len(faces)
|
||||
and not (faces.dim() == 2 and faces.size(1) == 3)
|
||||
):
|
||||
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
@ -995,10 +1151,20 @@ def save_ply(
|
||||
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
verts_colors = torch.FloatTensor([])
|
||||
|
||||
if path_manager is None:
|
||||
path_manager = PathManager()
|
||||
with _open_file(f, path_manager, "wb") as f:
|
||||
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)
|
||||
_save_ply(
|
||||
f,
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
verts_normals=verts_normals,
|
||||
verts_colors=verts_colors,
|
||||
ascii=ascii,
|
||||
decimal_places=decimal_places,
|
||||
)
|
||||
|
||||
|
||||
class MeshPlyFormat(MeshFormatInterpreter):
|
||||
@ -1044,3 +1210,54 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
path_manager=path_manager,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
||||
def __init__(self):
|
||||
self.known_suffixes = (".ply",)
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
**kwargs,
|
||||
) -> Optional[Pointclouds]:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces, features = _load_ply(
|
||||
f=path, path_manager=path_manager, return_vertex_colors=True
|
||||
)
|
||||
verts = verts.to(device)
|
||||
if features is None:
|
||||
pointcloud = Pointclouds(points=[verts])
|
||||
else:
|
||||
pointcloud = Pointclouds(points=[verts], features=[features.to(device)])
|
||||
return pointcloud
|
||||
|
||||
def save(
|
||||
self,
|
||||
data: Pointclouds,
|
||||
path: Union[str, Path],
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return False
|
||||
|
||||
points = data.points_list()[0]
|
||||
features = data.features_list()[0]
|
||||
with _open_file(path, path_manager, "wb") as f:
|
||||
_save_ply(
|
||||
f=f,
|
||||
verts=points,
|
||||
verts_colors=features,
|
||||
verts_normals=torch.FloatTensor([]),
|
||||
faces=None,
|
||||
ascii=binary is False,
|
||||
decimal_places=decimal_places,
|
||||
)
|
||||
return True
|
||||
|
@ -12,6 +12,7 @@ from common_testing import TestCaseMixin
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.io.ply_io import load_ply, save_ply
|
||||
from pytorch3d.structures import Pointclouds
|
||||
from pytorch3d.utils import torus
|
||||
|
||||
|
||||
@ -229,9 +230,13 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
|
||||
if not len(expected_faces): # Always compare with an (F, 3) tensor
|
||||
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||
|
||||
actual_verts, actual_faces = load_ply(f)
|
||||
self.assertClose(expected_verts, actual_verts)
|
||||
self.assertClose(expected_faces, actual_faces)
|
||||
if len(actual_verts):
|
||||
self.assertClose(expected_faces, actual_faces)
|
||||
else:
|
||||
self.assertEqual(actual_faces.numel(), 0)
|
||||
|
||||
def test_normals_save(self):
|
||||
verts = torch.tensor(
|
||||
@ -255,9 +260,10 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
self._test_save_load(verts, faces)
|
||||
|
||||
# Faces + empty vertices
|
||||
message_regex = "Faces have invalid indices"
|
||||
# => We don't save the faces
|
||||
verts = torch.FloatTensor([])
|
||||
faces = torch.LongTensor([[0, 1, 2]])
|
||||
message_regex = "Empty 'verts' provided"
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
self._test_save_load(verts, faces)
|
||||
|
||||
@ -266,7 +272,6 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
self._test_save_load(verts, faces)
|
||||
|
||||
# Empty vertices + empty faces
|
||||
message_regex = "Empty 'verts' and 'faces' arguments provided"
|
||||
verts0 = torch.FloatTensor([])
|
||||
faces0 = torch.LongTensor([])
|
||||
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||
@ -354,6 +359,115 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(x, X)
|
||||
self.assertClose(yz, YZ.reshape(8, 2))
|
||||
|
||||
def test_load_cloudcompare_pointcloud(self):
|
||||
"""
|
||||
Test loading a pointcloud styled like some cloudcompare output.
|
||||
cloudcompare is an open source 3D point cloud processing software.
|
||||
"""
|
||||
header = "\n".join(
|
||||
[
|
||||
"ply",
|
||||
"format binary_little_endian 1.0",
|
||||
"obj_info Not a key-value pair!",
|
||||
"element vertex 8",
|
||||
"property double x",
|
||||
"property double y",
|
||||
"property double z",
|
||||
"property uchar red",
|
||||
"property uchar green",
|
||||
"property uchar blue",
|
||||
"property float my_Favorite",
|
||||
"end_header",
|
||||
"",
|
||||
]
|
||||
).encode("ascii")
|
||||
data = struct.pack("<" + "dddBBBf" * 8, *range(56))
|
||||
io = IO()
|
||||
with NamedTemporaryFile(mode="wb", suffix=".ply") as f:
|
||||
f.write(header)
|
||||
f.write(data)
|
||||
f.flush()
|
||||
pointcloud = io.load_pointcloud(f.name)
|
||||
|
||||
self.assertClose(
|
||||
pointcloud.points_padded()[0],
|
||||
torch.FloatTensor([0, 1, 2]) + 7 * torch.arange(8)[:, None],
|
||||
)
|
||||
self.assertClose(
|
||||
pointcloud.features_padded()[0],
|
||||
torch.FloatTensor([3, 4, 5]) + 7 * torch.arange(8)[:, None],
|
||||
)
|
||||
|
||||
def test_save_pointcloud(self):
|
||||
header = "\n".join(
|
||||
[
|
||||
"ply",
|
||||
"format binary_little_endian 1.0",
|
||||
"element vertex 8",
|
||||
"property float x",
|
||||
"property float y",
|
||||
"property float z",
|
||||
"property float red",
|
||||
"property float green",
|
||||
"property float blue",
|
||||
"end_header",
|
||||
"",
|
||||
]
|
||||
).encode("ascii")
|
||||
data = struct.pack("<" + "f" * 48, *range(48))
|
||||
points = torch.FloatTensor([0, 1, 2]) + 6 * torch.arange(8)[:, None]
|
||||
features = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None]
|
||||
pointcloud = Pointclouds(points=[points], features=[features])
|
||||
|
||||
io = IO()
|
||||
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
|
||||
io.save_pointcloud(data=pointcloud, path=f.name)
|
||||
f.flush()
|
||||
f.seek(0)
|
||||
actual_data = f.read()
|
||||
reloaded_pointcloud = io.load_pointcloud(f.name)
|
||||
|
||||
self.assertEqual(header + data, actual_data)
|
||||
self.assertClose(reloaded_pointcloud.points_list()[0], points)
|
||||
self.assertClose(reloaded_pointcloud.features_list()[0], features)
|
||||
|
||||
with NamedTemporaryFile(mode="r", suffix=".ply") as f:
|
||||
io.save_pointcloud(data=pointcloud, path=f.name, binary=False)
|
||||
reloaded_pointcloud2 = io.load_pointcloud(f.name)
|
||||
self.assertEqual(f.readline(), "ply\n")
|
||||
self.assertEqual(f.readline(), "format ascii 1.0\n")
|
||||
self.assertClose(reloaded_pointcloud2.points_list()[0], points)
|
||||
self.assertClose(reloaded_pointcloud2.features_list()[0], features)
|
||||
|
||||
def test_load_pointcloud_bad_order(self):
|
||||
"""
|
||||
Ply file with a strange property order
|
||||
"""
|
||||
file = "\n".join(
|
||||
[
|
||||
"ply",
|
||||
"format ascii 1.0",
|
||||
"element vertex 1",
|
||||
"property uchar green",
|
||||
"property float x",
|
||||
"property float z",
|
||||
"property uchar red",
|
||||
"property float y",
|
||||
"property uchar blue",
|
||||
"end_header",
|
||||
"1 2 3 4 5 6",
|
||||
]
|
||||
)
|
||||
|
||||
io = IO()
|
||||
pointcloud_gpu = io.load_pointcloud(StringIO(file), device="cuda:0")
|
||||
self.assertEqual(pointcloud_gpu.device, torch.device("cuda:0"))
|
||||
pointcloud = pointcloud_gpu.to(torch.device("cpu"))
|
||||
expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32)
|
||||
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32)
|
||||
self.assertClose(pointcloud.points_padded(), expected_points)
|
||||
self.assertClose(pointcloud.features_padded(), expected_features)
|
||||
|
||||
def test_load_simple_binary(self):
|
||||
for big_endian in [True, False]:
|
||||
verts = (
|
||||
@ -569,9 +683,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "Inconsistent data for vertex."):
|
||||
_load_ply_raw(StringIO("\n".join(lines2)))
|
||||
|
||||
# Now make the ply file actually be readable as a Mesh
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "The ply file has no face element."):
|
||||
with self.assertRaisesRegex(ValueError, "Invalid vertices in file."):
|
||||
load_ply(StringIO("\n".join(lines)))
|
||||
|
||||
lines2 = lines.copy()
|
||||
|
Loading…
x
Reference in New Issue
Block a user