Use dataclasses inside ply_io.

Summary: Refactor ply_io to make it easier to add new features. Mostly taken from the starting code I attached to https://github.com/facebookresearch/pytorch3d/issues/904.

Reviewed By: patricklabatut

Differential Revision: D34375978

fbshipit-source-id: ec017d31f07c6f71ba6d97a0623bb10be1e81212
This commit is contained in:
Jeremy Reizenstein 2022-02-21 06:57:53 -08:00 committed by Facebook GitHub Bot
parent feb5d36394
commit 967a099231

View File

@ -14,8 +14,9 @@ import struct
import sys import sys
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from dataclasses import asdict, dataclass
from io import BytesIO, TextIOBase from io import BytesIO, TextIOBase
from typing import List, Optional, Tuple, cast from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -137,6 +138,7 @@ class _PlyHeader:
self.ascii: (bool) Whether in ascii format self.ascii: (bool) Whether in ascii format
self.big_endian: (bool) (if not ascii) whether big endian self.big_endian: (bool) (if not ascii) whether big endian
self.obj_info: (List[str]) arbitrary extra data self.obj_info: (List[str]) arbitrary extra data
self.comments: (List[str]) comments
Args: Args:
f: file-like object. f: file-like object.
@ -145,7 +147,8 @@ class _PlyHeader:
raise ValueError("Invalid file header.") raise ValueError("Invalid file header.")
seen_format = False seen_format = False
self.elements: List[_PlyElementType] = [] self.elements: List[_PlyElementType] = []
self.obj_info = [] self.comments: List[str] = []
self.obj_info: List[str] = []
while True: while True:
line = f.readline() line = f.readline()
if isinstance(line, bytes): if isinstance(line, bytes):
@ -176,6 +179,9 @@ class _PlyHeader:
continue continue
if line.startswith("format"): if line.startswith("format"):
raise ValueError("Invalid format line.") raise ValueError("Invalid format line.")
if line.startswith("comment "):
self.comments.append(line[8:])
continue
if line.startswith("comment") or len(line) == 0: if line.startswith("comment") or len(line) == 0:
continue continue
if line.startswith("element"): if line.startswith("element"):
@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
return header, elements return header, elements
@dataclass(frozen=True)
class _VertsColumnIndices:
"""
Contains the relevant layout of the verts section of file being read.
Members
point_idxs: List[int] of 3 point columns.
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
"""
point_idxs: List[int]
color_idxs: Optional[List[int]]
color_scale: float
normal_idxs: Optional[List[int]]
def _get_verts_column_indices( def _get_verts_column_indices(
vertex_head: _PlyElementType, vertex_head: _PlyElementType,
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]: ) -> _VertsColumnIndices:
""" """
Get the columns of verts, verts_colors, and verts_normals in the vertex Get the columns of verts, verts_colors, and verts_normals in the vertex
element of a parsed ply file, together with a color scale factor. element of a parsed ply file, together with a color scale factor.
@ -809,12 +834,7 @@ def _get_verts_column_indices(
vertex_head: as returned from load_ply_raw. vertex_head: as returned from load_ply_raw.
Returns: Returns:
point_idxs: List[int] of 3 point columns. _VertsColumnIndices object
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
""" """
point_idxs: List[Optional[int]] = [None, None, None] point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None] color_idxs: List[Optional[int]] = [None, None, None]
@ -839,19 +859,30 @@ def _get_verts_column_indices(
for idx in color_idxs for idx in color_idxs
): ):
color_scale = 1.0 / 255 color_scale = 1.0 / 255
return ( return _VertsColumnIndices(
point_idxs, point_idxs=point_idxs,
# pyre-fixme[22]: The cast is redundant. color_idxs=None if None in color_idxs else color_idxs,
None if None in color_idxs else cast(List[int], color_idxs), color_scale=color_scale,
color_scale, normal_idxs=None if None in normal_idxs else normal_idxs,
# pyre-fixme[22]: The cast is redundant.
None if None in normal_idxs else cast(List[int], normal_idxs),
) )
def _get_verts( @dataclass(frozen=True)
header: _PlyHeader, elements: dict class _VertsData:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """
Contains the data of the verts section of file being read.
Members:
verts: FloatTensor of shape (V, 3).
verts_colors: None or FloatTensor of shape (V, 3).
verts_normals: None or FloatTensor of shape (V, 3).
"""
verts: torch.Tensor
verts_colors: Optional[torch.Tensor] = None
verts_normals: Optional[torch.Tensor] = None
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
""" """
Get the vertex locations, colors and normals from a parsed ply file. Get the vertex locations, colors and normals from a parsed ply file.
@ -859,9 +890,7 @@ def _get_verts(
header, elements: as returned from load_ply_raw. header, elements: as returned from load_ply_raw.
Returns: Returns:
verts: FloatTensor of shape (V, 3). _VertsData object
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
""" """
vertex = elements.get("vertex", None) vertex = elements.get("vertex", None)
@ -870,16 +899,17 @@ def _get_verts(
if not isinstance(vertex, list): if not isinstance(vertex, list):
raise ValueError("Invalid vertices in file.") raise ValueError("Invalid vertices in file.")
vertex_head = next(head for head in header.elements if head.name == "vertex") vertex_head = next(head for head in header.elements if head.name == "vertex")
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
vertex_head column_idxs = _get_verts_column_indices(vertex_head)
)
# Case of no vertices # Case of no vertices
if vertex_head.count == 0: if vertex_head.count == 0:
verts = torch.zeros((0, 3), dtype=torch.float32) verts = torch.zeros((0, 3), dtype=torch.float32)
if color_idxs is None: if column_idxs.color_idxs is None:
return verts, None, None return _VertsData(verts=verts)
return verts, torch.zeros((0, 3), dtype=torch.float32), None return _VertsData(
verts=verts, verts_colors=torch.zeros((0, 3), dtype=torch.float32)
)
# Simple case where the only data is the vertices themselves # Simple case where the only data is the vertices themselves
if ( if (
@ -888,7 +918,7 @@ def _get_verts(
and vertex[0].ndim == 2 and vertex[0].ndim == 2
and vertex[0].shape[1] == 3 and vertex[0].shape[1] == 3
): ):
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None return _VertsData(verts=_make_tensor(vertex[0], cols=3, dtype=torch.float32))
vertex_colors = None vertex_colors = None
vertex_normals = None vertex_normals = None
@ -896,14 +926,14 @@ def _get_verts(
if len(vertex) == 1: if len(vertex) == 1:
# This is the case where the whole vertex element has one type, # This is the case where the whole vertex element has one type,
# so it was read as a single array and we can index straight into it. # so it was read as a single array and we can index straight into it.
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32) verts = torch.tensor(vertex[0][:, column_idxs.point_idxs], dtype=torch.float32)
if color_idxs is not None: if column_idxs.color_idxs is not None:
vertex_colors = color_scale * torch.tensor( vertex_colors = column_idxs.color_scale * torch.tensor(
vertex[0][:, color_idxs], dtype=torch.float32 vertex[0][:, column_idxs.color_idxs], dtype=torch.float32
) )
if normal_idxs is not None: if column_idxs.normal_idxs is not None:
vertex_normals = torch.tensor( vertex_normals = torch.tensor(
vertex[0][:, normal_idxs], dtype=torch.float32 vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
) )
else: else:
# The vertex element is heterogeneous. It was read as several arrays, # The vertex element is heterogeneous. It was read as several arrays,
@ -918,7 +948,7 @@ def _get_verts(
] ]
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32) verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
for axis in range(3): for axis in range(3):
partnum, col = prop_to_partnum_col[point_idxs[axis]] partnum, col = prop_to_partnum_col[column_idxs.point_idxs[axis]]
verts.numpy()[:, axis] = vertex[partnum][:, col] verts.numpy()[:, axis] = vertex[partnum][:, col]
# Note that in the previous line, we made the assignment # Note that in the previous line, we made the assignment
# as numpy arrays by casting verts. If we took the (more # as numpy arrays by casting verts. If we took the (more
@ -928,30 +958,49 @@ def _get_verts(
# if not vertex[partnum].flags["C_CONTIGUOUS"]: # if not vertex[partnum].flags["C_CONTIGUOUS"]:
# vertex[partnum] = np.ascontiguousarray(vertex[partnum]) # vertex[partnum] = np.ascontiguousarray(vertex[partnum])
# verts[:, axis] = torch.tensor((vertex[partnum][:, col])) # verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
if color_idxs is not None: if column_idxs.color_idxs is not None:
vertex_colors = torch.empty( vertex_colors = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32 size=(vertex_head.count, 3), dtype=torch.float32
) )
for color in range(3): for color in range(3):
partnum, col = prop_to_partnum_col[color_idxs[color]] partnum, col = prop_to_partnum_col[column_idxs.color_idxs[color]]
vertex_colors.numpy()[:, color] = vertex[partnum][:, col] vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
vertex_colors *= color_scale vertex_colors *= column_idxs.color_scale
if normal_idxs is not None: if column_idxs.normal_idxs is not None:
vertex_normals = torch.empty( vertex_normals = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32 size=(vertex_head.count, 3), dtype=torch.float32
) )
for axis in range(3): for axis in range(3):
partnum, col = prop_to_partnum_col[normal_idxs[axis]] partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col] vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
return verts, vertex_colors, vertex_normals return _VertsData(
verts=verts,
verts_colors=vertex_colors,
verts_normals=vertex_normals,
)
def _load_ply( @dataclass(frozen=True)
f, *, path_manager: PathManager class _PlyData:
) -> Tuple[ """
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor] Contains the data from a PLY file which has been read.
]: Members:
header: _PlyHeader of file metadata from the header
verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3).
verts_colors: None or FloatTensor of shape (V, 3).
verts_normals: None or FloatTensor of shape (V, 3).
"""
header: _PlyHeader
verts: torch.Tensor
faces: Optional[torch.Tensor]
verts_colors: Optional[torch.Tensor]
verts_normals: Optional[torch.Tensor]
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
""" """
Load the data from a .ply file. Load the data from a .ply file.
@ -964,14 +1013,11 @@ def _load_ply(
path_manager: PathManager for loading if f is a str. path_manager: PathManager for loading if f is a str.
Returns: Returns:
verts: FloatTensor of shape (V, 3). _PlyData object
faces: None or LongTensor of vertex indices, shape (F, 3).
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
""" """
header, elements = _load_ply_raw(f, path_manager=path_manager) header, elements = _load_ply_raw(f, path_manager=path_manager)
verts, vertex_colors, vertex_normals = _get_verts(header, elements) verts_data = _get_verts(header, elements)
face = elements.get("face", None) face = elements.get("face", None)
if face is not None: if face is not None:
@ -1007,9 +1053,9 @@ def _load_ply(
faces = torch.tensor(face_list, dtype=torch.int64) faces = torch.tensor(face_list, dtype=torch.int64)
if faces is not None: if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0]) _check_faces_indices(faces, max_index=verts_data.verts.shape[0])
return verts, faces, vertex_colors, vertex_normals return _PlyData(**asdict(verts_data), faces=faces, header=header)
def load_ply( def load_ply(
@ -1064,11 +1110,12 @@ def load_ply(
if path_manager is None: if path_manager is None:
path_manager = PathManager() path_manager = PathManager()
verts, faces, _, _ = _load_ply(f, path_manager=path_manager) data = _load_ply(f, path_manager=path_manager)
faces = data.faces
if faces is None: if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64) faces = torch.zeros(0, 3, dtype=torch.int64)
return verts, faces return data.verts, faces
def _write_ply_header( def _write_ply_header(
@ -1305,20 +1352,20 @@ class MeshPlyFormat(MeshFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces, verts_colors, verts_normals = _load_ply( data = _load_ply(f=path, path_manager=path_manager)
f=path, path_manager=path_manager faces = data.faces
)
if faces is None: if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64) faces = torch.zeros(0, 3, dtype=torch.int64)
texture = None texture = None
if include_textures and verts_colors is not None: if include_textures and data.verts_colors is not None:
texture = TexturesVertex([verts_colors.to(device)]) texture = TexturesVertex([data.verts_colors.to(device)])
if verts_normals is not None: verts_normals = None
verts_normals = [verts_normals] if data.verts_normals is not None:
verts_normals = [data.verts_normals.to(device)]
mesh = Meshes( mesh = Meshes(
verts=[verts.to(device)], verts=[data.verts.to(device)],
faces=[faces.to(device)], faces=[faces.to(device)],
textures=texture, textures=texture,
verts_normals=verts_normals, verts_normals=verts_normals,
@ -1392,14 +1439,17 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager) data = _load_ply(f=path, path_manager=path_manager)
verts = verts.to(device) features = None
if features is not None: if data.verts_colors is not None:
features = [features.to(device)] features = [data.verts_colors.to(device)]
if normals is not None: normals = None
normals = [normals.to(device)] if data.verts_normals is not None:
normals = [data.verts_normals.to(device)]
pointcloud = Pointclouds(points=[verts], features=features, normals=normals) pointcloud = Pointclouds(
points=[data.verts.to(device)], features=features, normals=normals
)
return pointcloud return pointcloud
def save( def save(