PLY pointcloud loading

Summary:
Allow PLY files that contain no faces. Allow loading point clouds with color, at least when the colors are encoded the way some CloudCompare examples encode them.

TODO: Allow vertex normals to be read, and allow vertex colors to be written. Make the return type of load_ply something more user friendly, like a dict.

Noticed in https://github.com/facebookresearch/pytorch3d/issues/209

Reviewed By: nikhilaravi

Differential Revision: D22573314

fbshipit-source-id: 72ba1f7c6417f5dfc83f2ebf359eff017057635c
This commit is contained in:
Jeremy Reizenstein 2021-01-07 15:38:49 -08:00 committed by Facebook GitHub Bot
parent 3b9fbfc08c
commit 95707fba1c
3 changed files with 423 additions and 93 deletions

View File

@ -12,7 +12,7 @@ from pytorch3d.structures import Meshes, Pointclouds
from .obj_io import MeshObjFormat
from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter
from .ply_io import MeshPlyFormat
from .ply_io import MeshPlyFormat, PointcloudPlyFormat
"""
@ -74,6 +74,7 @@ class IO:
def register_default_formats(self) -> None:
self.register_meshes_format(MeshObjFormat())
self.register_meshes_format(MeshPlyFormat())
self.register_pointcloud_format(PointcloudPlyFormat())
def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None:
"""

View File

@ -3,23 +3,30 @@
# LICENSE file in the root directory of this source tree.
"""This module implements utility functions for loading and saving meshes."""
"""
This module implements utility functions for loading and saving
meshes and point clouds from PLY files.
"""
import itertools
import struct
import sys
import warnings
from collections import namedtuple
from io import BytesIO
from io import BytesIO, TextIOBase
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.structures import Meshes
from pytorch3d.structures import Meshes, Pointclouds
from .pluggable_formats import MeshFormatInterpreter, endswith
from .pluggable_formats import (
MeshFormatInterpreter,
PointcloudFormatInterpreter,
endswith,
)
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
@ -127,7 +134,7 @@ class _PlyHeader:
self.elements: (List[_PlyElementType]) element description
self.ascii: (bool) Whether in ascii format
self.big_endian: (bool) (if not ascii) whether big endian
self.obj_info: (dict) arbitrary extra data
self.obj_info: (List[str]) arbitrary extra data
Args:
f: file-like object.
@ -136,7 +143,7 @@ class _PlyHeader:
raise ValueError("Invalid file header.")
seen_format = False
self.elements = []
self.obj_info = {}
self.obj_info = []
while True:
line = f.readline()
if isinstance(line, bytes):
@ -173,10 +180,7 @@ class _PlyHeader:
self._parse_element(line)
continue
if line.startswith("obj_info "):
items = line.split(" ")
if len(items) != 3:
raise ValueError("Invalid line: %s" % line)
self.obj_info[items[1]] = items[2]
self.obj_info.append(line[9:])
continue
if line.startswith("property"):
self._parse_property(line)
@ -736,6 +740,10 @@ def _load_ply_raw_stream(f) -> Tuple[_PlyHeader, dict]:
for element in header.elements:
elements[element.name] = _read_ply_element_ascii(f, element)
else:
if isinstance(f, TextIOBase):
raise ValueError(
"Cannot safely read a binary ply file using a Text stream."
)
big = header.big_endian
for element in header.elements:
elements[element.name] = _read_ply_element_binary(f, element, big)
@ -769,7 +777,187 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
return header, elements
def load_ply(f, path_manager: Optional[PathManager] = None):
def _get_verts_column_indices(
    vertex_head: _PlyElementType,
) -> Tuple[List[int], Optional[List[int]]]:
    """
    Locate the position and color columns of the vertex element of a
    parsed ply file.

    Args:
        vertex_head: as returned from load_ply_raw.

    Returns:
        point_idxs: List[int] holding the columns of x, y and z, in that
            order.
        color_idxs: List[int] holding the columns of red, green and blue,
            in that order, or None if any of the three is missing.

    Raises:
        ValueError: if a vertex property is a list, or if any of x, y, z
            is missing.
    """
    # Output slot of each recognised property name.
    point_slots = {"x": 0, "y": 1, "z": 2}
    color_slots = {"red": 0, "green": 1, "blue": 2}

    point_idxs: List[Optional[int]] = [None] * 3
    color_idxs: List[Optional[int]] = [None] * 3

    for column, prop in enumerate(vertex_head.properties):
        if prop.list_size_type is not None:
            raise ValueError("Invalid vertices in file: did not expect list.")
        prop_name = prop.name
        if prop_name in point_slots:
            point_idxs[point_slots[prop_name]] = column
        elif prop_name in color_slots:
            color_idxs[color_slots[prop_name]] = column

    if any(idx is None for idx in point_idxs):
        raise ValueError("Invalid vertices in file.")
    # Colors are all-or-nothing: a partial set is treated as no colors.
    if any(idx is None for idx in color_idxs):
        return point_idxs, None
    return point_idxs, color_idxs
def _get_verts(
    header: _PlyHeader, elements: dict
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Get the vertex locations and colors from a parsed ply file.

    Args:
        header, elements: as returned from load_ply_raw.

    Returns:
        verts: FloatTensor of shape (V, 3), with columns in x, y, z order
            regardless of the order of the properties in the file.
        vertex_colors: None or FloatTensor of shape (V, 3), with columns
            in red, green, blue order.

    Raises:
        ValueError: if the file has no vertex element, if the vertex data
            is not a list, or if the vertex properties are rejected by
            _get_verts_column_indices.
    """
    vertex = elements.get("vertex", None)
    if vertex is None:
        raise ValueError("The ply file has no vertex element.")
    if not isinstance(vertex, list):
        raise ValueError("Invalid vertices in file.")
    vertex_head = next(head for head in header.elements if head.name == "vertex")

    point_idxs, color_idxs = _get_verts_column_indices(vertex_head)

    # Case of no vertices
    if vertex_head.count == 0:
        verts = torch.zeros((0, 3), dtype=torch.float32)
        if color_idxs is None:
            return verts, None
        return verts, torch.zeros((0, 3), dtype=torch.float32)

    # Simple case where the only data is the vertices themselves.
    # The shortcut is only valid when the three columns are already in
    # x, y, z order; otherwise we fall through to the general code below,
    # which reorders the columns using point_idxs.
    if (
        len(vertex) == 1
        and isinstance(vertex[0], np.ndarray)
        and vertex[0].ndim == 2
        and vertex[0].shape[1] == 3
        and point_idxs == [0, 1, 2]
    ):
        return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None

    vertex_colors = None

    if len(vertex) == 1:
        # This is the case where the whole vertex element has one type,
        # so it was read as a single array and we can index straight into it.
        verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
        if color_idxs is not None:
            vertex_colors = torch.tensor(
                vertex[0][:, color_idxs], dtype=torch.float32
            )
    else:
        # The vertex element is heterogeneous. It was read as several arrays,
        # part by part, where a part is a set of properties with the same type.
        # For each property (=column in the file), we store in
        # prop_to_partnum_col its partnum (i.e. the index of what part it is
        # in) and its column number (its index within its part).
        prop_to_partnum_col = [
            (partnum, col)
            for partnum, array in enumerate(vertex)
            for col in range(array.shape[1])
        ]
        verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
        for axis in range(3):
            partnum, col = prop_to_partnum_col[point_idxs[axis]]
            verts.numpy()[:, axis] = vertex[partnum][:, col]
            # Note that in the previous line, we made the assignment
            # as numpy arrays by casting verts. If we took the (more
            # obvious) method of converting the right hand side to
            # torch, then we might have an extra data copy because
            # torch wants contiguity. The code would be like:
            #   if not vertex[partnum].flags["C_CONTIGUOUS"]:
            #      vertex[partnum] = np.ascontiguousarray(vertex[partnum])
            #   verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
        if color_idxs is not None:
            vertex_colors = torch.empty(
                size=(vertex_head.count, 3), dtype=torch.float32
            )
            for color in range(3):
                partnum, col = prop_to_partnum_col[color_idxs[color]]
                vertex_colors.numpy()[:, color] = vertex[partnum][:, col]

    return verts, vertex_colors
def _load_ply(
    f, *, path_manager: PathManager, return_vertex_colors: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Load the data from a .ply file.

    Args:
        f:  A binary or text file-like object (with methods read, readline,
            tell and seek), a pathlib path or a string containing a file name.
            If the ply file is in the binary ply format rather than the text
            ply format, then a text stream is not supported.
            It is easiest to use a binary stream in all cases.
        path_manager: PathManager for loading if f is a str.
        return_vertex_colors: whether to return vertex colors.

    Returns:
        verts: FloatTensor of shape (V, 3).
        faces: None if the file has no face element or the face element is
            empty, otherwise a LongTensor of vertex indices, shape (F, 3).
            Faces with more than 3 vertices are split into a fan of
            triangles sharing the face's first vertex.
        vertex_colors: None or FloatTensor of shape (V, 3), only if requested

    Raises:
        ValueError: if the face element does not consist of exactly one
            list property, if a face has fewer than 3 vertices, or if the
            face data has an unexpected shape. Errors from reading the
            vertices and from _check_faces_indices also propagate.
    """
    header, elements = _load_ply_raw(f, path_manager=path_manager)

    verts, vertex_colors = _get_verts(header, elements)

    face = elements.get("face", None)
    if face is not None:
        face_head = next(head for head in header.elements if head.name == "face")
        if (
            len(face_head.properties) != 1
            or face_head.properties[0].list_size_type is None
        ):
            raise ValueError("Unexpected form of faces data.")
        # face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
        # but we don't need to enforce this.

    if face is None:
        faces = None
    elif not len(face):
        # pyre is happier when this condition is not joined to the
        # previous one with `or`.
        faces = None
    elif isinstance(face, np.ndarray) and face.ndim == 2:  # Homogeneous elements
        if face.shape[1] < 3:
            raise ValueError("Faces must have at least 3 vertices.")
        # Fan triangulation: (0, 1, 2), (0, 2, 3), ... for every face at once.
        face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)]
        faces = torch.LongTensor(np.vstack(face_arrays))
    else:
        # Faces with differing vertex counts: triangulate them one by one.
        face_list = []
        for face_item in face:
            if face_item.ndim != 1:
                raise ValueError("Bad face data.")
            if face_item.shape[0] < 3:
                raise ValueError("Faces must have at least 3 vertices.")
            for i in range(face_item.shape[0] - 2):
                face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
        faces = torch.tensor(face_list, dtype=torch.int64)

    if faces is not None:
        _check_faces_indices(faces, max_index=verts.shape[0])

    if return_vertex_colors:
        return verts, faces, vertex_colors
    return verts, faces, None
def load_ply(
f, *, path_manager: Optional[PathManager] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Load the data from a .ply file.
@ -809,72 +997,27 @@ def load_ply(f, path_manager: Optional[PathManager] = None):
It is easiest to use a binary stream in all cases.
path_manager: PathManager for loading if f is a str.
Returns:
verts: FloatTensor of shape (V, 3).
faces: LongTensor of vertex indices, shape (F, 3).
"""
if path_manager is None:
path_manager = PathManager()
header, elements = _load_ply_raw(f, path_manager=path_manager)
verts, faces, _ = _load_ply(f, path_manager=path_manager)
if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)
vertex = elements.get("vertex", None)
if vertex is None:
raise ValueError("The ply file has no vertex element.")
face = elements.get("face", None)
if face is None:
raise ValueError("The ply file has no face element.")
if not isinstance(vertex, list) or len(vertex) > 1:
raise ValueError("Invalid vertices in file.")
if len(vertex):
vertex0 = vertex[0]
if len(vertex0) and (
not isinstance(vertex0, np.ndarray)
or vertex0.ndim != 2
or vertex0.shape[1] != 3
):
raise ValueError("Invalid vertices in file.")
else:
vertex0 = []
verts = _make_tensor(vertex0, cols=3, dtype=torch.float32)
face_head = next(head for head in header.elements if head.name == "face")
if len(face_head.properties) != 1 or face_head.properties[0].list_size_type is None:
raise ValueError("Unexpected form of faces data.")
# face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
# but we don't need to enforce this.
if not len(face):
faces = torch.zeros((0, 3), dtype=torch.int64)
elif isinstance(face, np.ndarray) and face.ndim == 2: # Homogeneous elements
if face.shape[1] < 3:
raise ValueError("Faces must have at least 3 vertices.")
face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)]
faces = torch.LongTensor(np.vstack(face_arrays))
else:
face_list = []
for face_item in face:
if face_item.ndim != 1:
raise ValueError("Bad face data.")
if face_item.shape[0] < 3:
raise ValueError("Faces must have at least 3 vertices.")
for i in range(face_item.shape[0] - 2):
face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
# pyre-fixme[6]: Expected `dtype` for 3rd param but got `Type[torch.int64]`.
faces = _make_tensor(face_list, cols=3, dtype=torch.int64)
_check_faces_indices(faces, max_index=verts.shape[0])
return verts, faces
def _save_ply(
f,
*,
verts: torch.Tensor,
faces: torch.LongTensor,
faces: Optional[torch.LongTensor],
verts_normals: torch.Tensor,
verts_colors: torch.Tensor,
ascii: bool,
decimal_places: Optional[int] = None,
) -> None:
@ -890,10 +1033,14 @@ def _save_ply(
decimal_places: Number of decimal places for saving if ascii=True.
"""
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
if faces is not None:
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
assert not len(verts_normals) or (
verts_normals.dim() == 2 and verts_normals.size(1) == 3
)
assert not len(verts_colors) or (
verts_colors.dim() == 2 and verts_colors.size(1) == 3
)
if ascii:
f.write(b"ply\nformat ascii 1.0\n")
@ -909,15 +1056,20 @@ def _save_ply(
f.write(b"property float nx\n")
f.write(b"property float ny\n")
f.write(b"property float nz\n")
if verts_colors.numel() > 0:
f.write(b"property float red\n")
f.write(b"property float green\n")
f.write(b"property float blue\n")
if len(verts) and faces is not None:
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
if not (len(verts)):
warnings.warn("Empty 'verts' provided")
return
vert_data = torch.cat((verts, verts_normals), dim=1).detach().numpy()
vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy()
if ascii:
if decimal_places is None:
float_str = "%f"
@ -932,6 +1084,7 @@ def _save_ply(
else:
vert_data.tofile(f)
if faces is not None:
faces_array = faces.detach().numpy()
_check_faces_indices(faces, max_index=verts.shape[0])
@ -977,13 +1130,16 @@ def save_ply(
if verts_normals is None
else verts_normals
)
faces = torch.LongTensor([]) if faces is None else faces
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
if (
faces is not None
and len(faces)
and not (faces.dim() == 2 and faces.size(1) == 3)
):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
@ -995,10 +1151,20 @@ def save_ply(
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
verts_colors = torch.FloatTensor([])
if path_manager is None:
path_manager = PathManager()
with _open_file(f, path_manager, "wb") as f:
_save_ply(f, verts, faces, verts_normals, ascii, decimal_places)
_save_ply(
f,
verts=verts,
faces=faces,
verts_normals=verts_normals,
verts_colors=verts_colors,
ascii=ascii,
decimal_places=decimal_places,
)
class MeshPlyFormat(MeshFormatInterpreter):
@ -1044,3 +1210,54 @@ class MeshPlyFormat(MeshFormatInterpreter):
path_manager=path_manager,
)
return True
class PointcloudPlyFormat(PointcloudFormatInterpreter):
    """
    Interpreter for reading and writing a Pointclouds object as a PLY file.

    Only files whose path matches known_suffixes (".ply") are handled.
    """

    def __init__(self) -> None:
        self.known_suffixes = (".ply",)

    def read(
        self,
        path: Union[str, Path],
        device,
        path_manager: PathManager,
        **kwargs,
    ) -> Optional[Pointclouds]:
        """
        Load a pointcloud from a PLY file.

        Args:
            path: location of the file.
            device: device on which the returned tensors are placed.
            path_manager: PathManager used to open the file.

        Returns:
            None if path does not match known_suffixes. Otherwise a
            Pointclouds object holding one cloud, whose features are the
            vertex colors of the file if it has them. Any faces in the
            file are ignored.
        """
        if not endswith(path, self.known_suffixes):
            return None

        verts, _, features = _load_ply(
            f=path, path_manager=path_manager, return_vertex_colors=True
        )
        verts = verts.to(device)
        if features is None:
            return Pointclouds(points=[verts])
        return Pointclouds(points=[verts], features=[features.to(device)])

    def save(
        self,
        data: Pointclouds,
        path: Union[str, Path],
        path_manager: PathManager,
        binary: Optional[bool],
        decimal_places: Optional[int] = None,
        **kwargs,
    ) -> bool:
        """
        Save the first cloud of a Pointclouds object to a PLY file.

        Args:
            data: the pointcloud to save. Only the first cloud of the batch
                is written.
            path: location of the file.
            path_manager: PathManager used to open the file.
            binary: the file is written as ascii only if this is exactly
                False; None or True give a binary file.
            decimal_places: number of decimal places for saving if ascii.

        Returns:
            False if path does not match known_suffixes, otherwise True
            after the file has been written.
        """
        if not endswith(path, self.known_suffixes):
            return False

        # _save_ply converts its inputs with .numpy(), which only works for
        # CPU tensors, so clouds living on another device are moved first.
        points = data.points_list()[0].cpu()

        # A cloud without features has no features list at all; write it
        # without color properties instead of failing on the index.
        features_list = data.features_list()
        if features_list is None:
            features = torch.FloatTensor([])
        else:
            features = features_list[0].cpu()

        with _open_file(path, path_manager, "wb") as f:
            _save_ply(
                f=f,
                verts=points,
                verts_colors=features,
                verts_normals=torch.FloatTensor([]),
                faces=None,
                ascii=binary is False,
                decimal_places=decimal_places,
            )
        return True

View File

@ -12,6 +12,7 @@ from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager
from pytorch3d.io import IO
from pytorch3d.io.ply_io import load_ply, save_ply
from pytorch3d.structures import Pointclouds
from pytorch3d.utils import torus
@ -229,9 +230,13 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
if not len(expected_faces): # Always compare with an (F, 3) tensor
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
actual_verts, actual_faces = load_ply(f)
self.assertClose(expected_verts, actual_verts)
if len(actual_verts):
self.assertClose(expected_faces, actual_faces)
else:
self.assertEqual(actual_faces.numel(), 0)
def test_normals_save(self):
verts = torch.tensor(
@ -255,9 +260,10 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
self._test_save_load(verts, faces)
# Faces + empty vertices
message_regex = "Faces have invalid indices"
# => We don't save the faces
verts = torch.FloatTensor([])
faces = torch.LongTensor([[0, 1, 2]])
message_regex = "Empty 'verts' provided"
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts, faces)
@ -266,7 +272,6 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
self._test_save_load(verts, faces)
# Empty vertices + empty faces
message_regex = "Empty 'verts' and 'faces' arguments provided"
verts0 = torch.FloatTensor([])
faces0 = torch.LongTensor([])
with self.assertWarnsRegex(UserWarning, message_regex):
@ -354,6 +359,115 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
self.assertClose(x, X)
self.assertClose(yz, YZ.reshape(8, 2))
def test_load_cloudcompare_pointcloud(self):
    """
    Test loading a pointcloud styled like some cloudcompare output.
    cloudcompare is an open source 3D point cloud processing software.

    The file has double positions, uchar colors, an extra float property
    and an obj_info line which is not a key-value pair.
    """
    num_points = 8
    values_per_point = 7  # x, y, z, red, green, blue, my_Favorite
    header_lines = [
        "ply",
        "format binary_little_endian 1.0",
        "obj_info Not a key-value pair!",
        "element vertex 8",
        "property double x",
        "property double y",
        "property double z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "property float my_Favorite",
        "end_header",
        "",
    ]
    header = "\n".join(header_lines).encode("ascii")
    # Point i holds the consecutive values 7 * i, ..., 7 * i + 6.
    data = struct.pack(
        "<" + "dddBBBf" * num_points, *range(num_points * values_per_point)
    )

    io = IO()
    with NamedTemporaryFile(mode="wb", suffix=".ply") as f:
        f.write(header)
        f.write(data)
        f.flush()
        pointcloud = io.load_pointcloud(f.name)

    row_start = values_per_point * torch.arange(num_points)[:, None]
    expected_points = torch.FloatTensor([0, 1, 2]) + row_start
    expected_features = torch.FloatTensor([3, 4, 5]) + row_start
    self.assertClose(pointcloud.points_padded()[0], expected_points)
    self.assertClose(pointcloud.features_padded()[0], expected_features)
def test_save_pointcloud(self):
    """
    Save a pointcloud with features in binary and in ascii form, checking
    the bytes of the binary file and that both files load back unchanged.
    """
    expected_header = "\n".join(
        [
            "ply",
            "format binary_little_endian 1.0",
            "element vertex 8",
            "property float x",
            "property float y",
            "property float z",
            "property float red",
            "property float green",
            "property float blue",
            "end_header",
            "",
        ]
    ).encode("ascii")
    # 8 points, each followed by its 3 features: the values 0..47 in order.
    expected_body = struct.pack("<" + "f" * 48, *range(48))

    row_start = 6 * torch.arange(8)[:, None]
    points = torch.FloatTensor([0, 1, 2]) + row_start
    features = torch.FloatTensor([3, 4, 5]) + row_start
    pointcloud = Pointclouds(points=[points], features=[features])

    io = IO()

    # Binary output is the default.
    with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
        io.save_pointcloud(data=pointcloud, path=f.name)
        f.flush()
        f.seek(0)
        actual_data = f.read()
        reloaded_pointcloud = io.load_pointcloud(f.name)

    self.assertEqual(expected_header + expected_body, actual_data)
    self.assertClose(reloaded_pointcloud.points_list()[0], points)
    self.assertClose(reloaded_pointcloud.features_list()[0], features)

    # Ascii output has to be requested explicitly.
    with NamedTemporaryFile(mode="r", suffix=".ply") as f:
        io.save_pointcloud(data=pointcloud, path=f.name, binary=False)
        reloaded_pointcloud2 = io.load_pointcloud(f.name)
        first_line = f.readline()
        second_line = f.readline()
        self.assertEqual(first_line, "ply\n")
        self.assertEqual(second_line, "format ascii 1.0\n")

    self.assertClose(reloaded_pointcloud2.points_list()[0], points)
    self.assertClose(reloaded_pointcloud2.features_list()[0], features)
def test_load_pointcloud_bad_order(self):
    """
    Ply file with a strange property order: positions and colors are
    interleaved and out of order, and must come back as x, y, z and
    red, green, blue. The cloud is loaded onto the GPU when one is
    available, so that device placement is checked as well.
    """
    file = "\n".join(
        [
            "ply",
            "format ascii 1.0",
            "element vertex 1",
            "property uchar green",
            "property float x",
            "property float z",
            "property uchar red",
            "property float y",
            "property uchar blue",
            "end_header",
            "1 2 3 4 5 6",
        ]
    )

    # Fall back to the CPU so that the ordering check still runs on
    # machines without CUDA instead of erroring out.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    io = IO()
    pointcloud_on_device = io.load_pointcloud(StringIO(file), device=device)
    self.assertEqual(pointcloud_on_device.device, torch.device(device))
    pointcloud = pointcloud_on_device.to(torch.device("cpu"))
    expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32)
    expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32)
    self.assertClose(pointcloud.points_padded(), expected_points)
    self.assertClose(pointcloud.features_padded(), expected_features)
def test_load_simple_binary(self):
for big_endian in [True, False]:
verts = (
@ -569,9 +683,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "Inconsistent data for vertex."):
_load_ply_raw(StringIO("\n".join(lines2)))
# Now make the ply file actually be readable as a Mesh
with self.assertRaisesRegex(ValueError, "The ply file has no face element."):
with self.assertRaisesRegex(ValueError, "Invalid vertices in file."):
load_ply(StringIO("\n".join(lines)))
lines2 = lines.copy()