Use dataclasses inside ply_io.

Summary: Refactor ply_io to make it easier to add new features. Mostly taken from the starting code I attached to https://github.com/facebookresearch/pytorch3d/issues/904.

Reviewed By: patricklabatut

Differential Revision: D34375978

fbshipit-source-id: ec017d31f07c6f71ba6d97a0623bb10be1e81212
This commit is contained in:
Jeremy Reizenstein 2022-02-21 06:57:53 -08:00 committed by Facebook GitHub Bot
parent feb5d36394
commit 967a099231

View File

@ -14,8 +14,9 @@ import struct
import sys
import warnings
from collections import namedtuple
from dataclasses import asdict, dataclass
from io import BytesIO, TextIOBase
from typing import List, Optional, Tuple, cast
from typing import List, Optional, Tuple
import numpy as np
import torch
@ -137,6 +138,7 @@ class _PlyHeader:
self.ascii: (bool) Whether in ascii format
self.big_endian: (bool) (if not ascii) whether big endian
self.obj_info: (List[str]) arbitrary extra data
self.comments: (List[str]) comments
Args:
f: file-like object.
@ -145,7 +147,8 @@ class _PlyHeader:
raise ValueError("Invalid file header.")
seen_format = False
self.elements: List[_PlyElementType] = []
self.obj_info = []
self.comments: List[str] = []
self.obj_info: List[str] = []
while True:
line = f.readline()
if isinstance(line, bytes):
@ -176,6 +179,9 @@ class _PlyHeader:
continue
if line.startswith("format"):
raise ValueError("Invalid format line.")
if line.startswith("comment "):
self.comments.append(line[8:])
continue
if line.startswith("comment") or len(line) == 0:
continue
if line.startswith("element"):
@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
return header, elements
@dataclass(frozen=True)
class _VertsColumnIndices:
"""
Contains the relevant layout of the verts section of file being read.
Members
point_idxs: List[int] of 3 point columns.
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
"""
point_idxs: List[int]
color_idxs: Optional[List[int]]
color_scale: float
normal_idxs: Optional[List[int]]
def _get_verts_column_indices(
vertex_head: _PlyElementType,
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
) -> _VertsColumnIndices:
"""
Get the columns of verts, verts_colors, and verts_normals in the vertex
element of a parsed ply file, together with a color scale factor.
@ -809,12 +834,7 @@ def _get_verts_column_indices(
vertex_head: as returned from load_ply_raw.
Returns:
point_idxs: List[int] of 3 point columns.
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
_VertsColumnIndices object
"""
point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None]
@ -839,19 +859,30 @@ def _get_verts_column_indices(
for idx in color_idxs
):
color_scale = 1.0 / 255
return (
point_idxs,
# pyre-fixme[22]: The cast is redundant.
None if None in color_idxs else cast(List[int], color_idxs),
color_scale,
# pyre-fixme[22]: The cast is redundant.
None if None in normal_idxs else cast(List[int], normal_idxs),
return _VertsColumnIndices(
point_idxs=point_idxs,
color_idxs=None if None in color_idxs else color_idxs,
color_scale=color_scale,
normal_idxs=None if None in normal_idxs else normal_idxs,
)
def _get_verts(
header: _PlyHeader, elements: dict
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
@dataclass(frozen=True)
class _VertsData:
"""
Contains the data of the verts section of file being read.
Members:
verts: FloatTensor of shape (V, 3).
verts_colors: None or FloatTensor of shape (V, 3).
verts_normals: None or FloatTensor of shape (V, 3).
"""
verts: torch.Tensor
verts_colors: Optional[torch.Tensor] = None
verts_normals: Optional[torch.Tensor] = None
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
"""
Get the vertex locations, colors and normals from a parsed ply file.
@ -859,9 +890,7 @@ def _get_verts(
header, elements: as returned from load_ply_raw.
Returns:
verts: FloatTensor of shape (V, 3).
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
_VertsData object
"""
vertex = elements.get("vertex", None)
@ -870,16 +899,17 @@ def _get_verts(
if not isinstance(vertex, list):
raise ValueError("Invalid vertices in file.")
vertex_head = next(head for head in header.elements if head.name == "vertex")
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
vertex_head
)
column_idxs = _get_verts_column_indices(vertex_head)
# Case of no vertices
if vertex_head.count == 0:
verts = torch.zeros((0, 3), dtype=torch.float32)
if color_idxs is None:
return verts, None, None
return verts, torch.zeros((0, 3), dtype=torch.float32), None
if column_idxs.color_idxs is None:
return _VertsData(verts=verts)
return _VertsData(
verts=verts, verts_colors=torch.zeros((0, 3), dtype=torch.float32)
)
# Simple case where the only data is the vertices themselves
if (
@ -888,7 +918,7 @@ def _get_verts(
and vertex[0].ndim == 2
and vertex[0].shape[1] == 3
):
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
return _VertsData(verts=_make_tensor(vertex[0], cols=3, dtype=torch.float32))
vertex_colors = None
vertex_normals = None
@ -896,14 +926,14 @@ def _get_verts(
if len(vertex) == 1:
# This is the case where the whole vertex element has one type,
# so it was read as a single array and we can index straight into it.
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
if color_idxs is not None:
vertex_colors = color_scale * torch.tensor(
vertex[0][:, color_idxs], dtype=torch.float32
verts = torch.tensor(vertex[0][:, column_idxs.point_idxs], dtype=torch.float32)
if column_idxs.color_idxs is not None:
vertex_colors = column_idxs.color_scale * torch.tensor(
vertex[0][:, column_idxs.color_idxs], dtype=torch.float32
)
if normal_idxs is not None:
if column_idxs.normal_idxs is not None:
vertex_normals = torch.tensor(
vertex[0][:, normal_idxs], dtype=torch.float32
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
)
else:
# The vertex element is heterogeneous. It was read as several arrays,
@ -918,7 +948,7 @@ def _get_verts(
]
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
for axis in range(3):
partnum, col = prop_to_partnum_col[point_idxs[axis]]
partnum, col = prop_to_partnum_col[column_idxs.point_idxs[axis]]
verts.numpy()[:, axis] = vertex[partnum][:, col]
# Note that in the previous line, we made the assignment
# as numpy arrays by casting verts. If we took the (more
@ -928,30 +958,49 @@ def _get_verts(
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
if color_idxs is not None:
if column_idxs.color_idxs is not None:
vertex_colors = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32
)
for color in range(3):
partnum, col = prop_to_partnum_col[color_idxs[color]]
partnum, col = prop_to_partnum_col[column_idxs.color_idxs[color]]
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
vertex_colors *= color_scale
if normal_idxs is not None:
vertex_colors *= column_idxs.color_scale
if column_idxs.normal_idxs is not None:
vertex_normals = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32
)
for axis in range(3):
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
return verts, vertex_colors, vertex_normals
return _VertsData(
verts=verts,
verts_colors=vertex_colors,
verts_normals=vertex_normals,
)
def _load_ply(
f, *, path_manager: PathManager
) -> Tuple[
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
]:
@dataclass(frozen=True)
class _PlyData:
"""
Contains the data from a PLY file which has been read.
Members:
header: _PlyHeader of file metadata from the header
verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3).
verts_colors: None or FloatTensor of shape (V, 3).
verts_normals: None or FloatTensor of shape (V, 3).
"""
header: _PlyHeader
verts: torch.Tensor
faces: Optional[torch.Tensor]
verts_colors: Optional[torch.Tensor]
verts_normals: Optional[torch.Tensor]
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
"""
Load the data from a .ply file.
@ -964,14 +1013,11 @@ def _load_ply(
path_manager: PathManager for loading if f is a str.
Returns:
verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3).
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
_PlyData object
"""
header, elements = _load_ply_raw(f, path_manager=path_manager)
verts, vertex_colors, vertex_normals = _get_verts(header, elements)
verts_data = _get_verts(header, elements)
face = elements.get("face", None)
if face is not None:
@ -1007,9 +1053,9 @@ def _load_ply(
faces = torch.tensor(face_list, dtype=torch.int64)
if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0])
_check_faces_indices(faces, max_index=verts_data.verts.shape[0])
return verts, faces, vertex_colors, vertex_normals
return _PlyData(**asdict(verts_data), faces=faces, header=header)
def load_ply(
@ -1064,11 +1110,12 @@ def load_ply(
if path_manager is None:
path_manager = PathManager()
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
data = _load_ply(f, path_manager=path_manager)
faces = data.faces
if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)
return verts, faces
return data.verts, faces
def _write_ply_header(
@ -1305,20 +1352,20 @@ class MeshPlyFormat(MeshFormatInterpreter):
if not endswith(path, self.known_suffixes):
return None
verts, faces, verts_colors, verts_normals = _load_ply(
f=path, path_manager=path_manager
)
data = _load_ply(f=path, path_manager=path_manager)
faces = data.faces
if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)
texture = None
if include_textures and verts_colors is not None:
texture = TexturesVertex([verts_colors.to(device)])
if include_textures and data.verts_colors is not None:
texture = TexturesVertex([data.verts_colors.to(device)])
if verts_normals is not None:
verts_normals = [verts_normals]
verts_normals = None
if data.verts_normals is not None:
verts_normals = [data.verts_normals.to(device)]
mesh = Meshes(
verts=[verts.to(device)],
verts=[data.verts.to(device)],
faces=[faces.to(device)],
textures=texture,
verts_normals=verts_normals,
@ -1392,14 +1439,17 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
if not endswith(path, self.known_suffixes):
return None
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
verts = verts.to(device)
if features is not None:
features = [features.to(device)]
if normals is not None:
normals = [normals.to(device)]
data = _load_ply(f=path, path_manager=path_manager)
features = None
if data.verts_colors is not None:
features = [data.verts_colors.to(device)]
normals = None
if data.verts_normals is not None:
normals = [data.verts_normals.to(device)]
pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
pointcloud = Pointclouds(
points=[data.verts.to(device)], features=features, normals=normals
)
return pointcloud
def save(