mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Use dataclasses inside ply_io.
Summary: Refactor ply_io to make it easier to add new features. Mostly taken from the starting code I attached to https://github.com/facebookresearch/pytorch3d/issues/904. Reviewed By: patricklabatut Differential Revision: D34375978 fbshipit-source-id: ec017d31f07c6f71ba6d97a0623bb10be1e81212
This commit is contained in:
parent
feb5d36394
commit
967a099231
@ -14,8 +14,9 @@ import struct
|
||||
import sys
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from dataclasses import asdict, dataclass
|
||||
from io import BytesIO, TextIOBase
|
||||
from typing import List, Optional, Tuple, cast
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -137,6 +138,7 @@ class _PlyHeader:
|
||||
self.ascii: (bool) Whether in ascii format
|
||||
self.big_endian: (bool) (if not ascii) whether big endian
|
||||
self.obj_info: (List[str]) arbitrary extra data
|
||||
self.comments: (List[str]) comments
|
||||
|
||||
Args:
|
||||
f: file-like object.
|
||||
@ -145,7 +147,8 @@ class _PlyHeader:
|
||||
raise ValueError("Invalid file header.")
|
||||
seen_format = False
|
||||
self.elements: List[_PlyElementType] = []
|
||||
self.obj_info = []
|
||||
self.comments: List[str] = []
|
||||
self.obj_info: List[str] = []
|
||||
while True:
|
||||
line = f.readline()
|
||||
if isinstance(line, bytes):
|
||||
@ -176,6 +179,9 @@ class _PlyHeader:
|
||||
continue
|
||||
if line.startswith("format"):
|
||||
raise ValueError("Invalid format line.")
|
||||
if line.startswith("comment "):
|
||||
self.comments.append(line[8:])
|
||||
continue
|
||||
if line.startswith("comment") or len(line) == 0:
|
||||
continue
|
||||
if line.startswith("element"):
|
||||
@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
|
||||
return header, elements
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _VertsColumnIndices:
|
||||
"""
|
||||
Contains the relevant layout of the verts section of file being read.
|
||||
Members
|
||||
point_idxs: List[int] of 3 point columns.
|
||||
color_idxs: List[int] of 3 color columns if they are present,
|
||||
otherwise None.
|
||||
color_scale: value to scale colors by.
|
||||
normal_idxs: List[int] of 3 normals columns if they are present,
|
||||
otherwise None.
|
||||
"""
|
||||
|
||||
point_idxs: List[int]
|
||||
color_idxs: Optional[List[int]]
|
||||
color_scale: float
|
||||
normal_idxs: Optional[List[int]]
|
||||
|
||||
|
||||
def _get_verts_column_indices(
|
||||
vertex_head: _PlyElementType,
|
||||
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
|
||||
) -> _VertsColumnIndices:
|
||||
"""
|
||||
Get the columns of verts, verts_colors, and verts_normals in the vertex
|
||||
element of a parsed ply file, together with a color scale factor.
|
||||
@ -809,12 +834,7 @@ def _get_verts_column_indices(
|
||||
vertex_head: as returned from load_ply_raw.
|
||||
|
||||
Returns:
|
||||
point_idxs: List[int] of 3 point columns.
|
||||
color_idxs: List[int] of 3 color columns if they are present,
|
||||
otherwise None.
|
||||
color_scale: value to scale colors by.
|
||||
normal_idxs: List[int] of 3 normals columns if they are present,
|
||||
otherwise None.
|
||||
_VertsColumnIndices object
|
||||
"""
|
||||
point_idxs: List[Optional[int]] = [None, None, None]
|
||||
color_idxs: List[Optional[int]] = [None, None, None]
|
||||
@ -839,19 +859,30 @@ def _get_verts_column_indices(
|
||||
for idx in color_idxs
|
||||
):
|
||||
color_scale = 1.0 / 255
|
||||
return (
|
||||
point_idxs,
|
||||
# pyre-fixme[22]: The cast is redundant.
|
||||
None if None in color_idxs else cast(List[int], color_idxs),
|
||||
color_scale,
|
||||
# pyre-fixme[22]: The cast is redundant.
|
||||
None if None in normal_idxs else cast(List[int], normal_idxs),
|
||||
return _VertsColumnIndices(
|
||||
point_idxs=point_idxs,
|
||||
color_idxs=None if None in color_idxs else color_idxs,
|
||||
color_scale=color_scale,
|
||||
normal_idxs=None if None in normal_idxs else normal_idxs,
|
||||
)
|
||||
|
||||
|
||||
def _get_verts(
|
||||
header: _PlyHeader, elements: dict
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
@dataclass(frozen=True)
|
||||
class _VertsData:
|
||||
"""
|
||||
Contains the data of the verts section of file being read.
|
||||
Members:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
verts_colors: None or FloatTensor of shape (V, 3).
|
||||
verts_normals: None or FloatTensor of shape (V, 3).
|
||||
"""
|
||||
|
||||
verts: torch.Tensor
|
||||
verts_colors: Optional[torch.Tensor] = None
|
||||
verts_normals: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
|
||||
"""
|
||||
Get the vertex locations, colors and normals from a parsed ply file.
|
||||
|
||||
@ -859,9 +890,7 @@ def _get_verts(
|
||||
header, elements: as returned from load_ply_raw.
|
||||
|
||||
Returns:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
||||
vertex_normals: None or FloatTensor of shape (V, 3).
|
||||
_VertsData object
|
||||
"""
|
||||
|
||||
vertex = elements.get("vertex", None)
|
||||
@ -870,16 +899,17 @@ def _get_verts(
|
||||
if not isinstance(vertex, list):
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
||||
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
|
||||
vertex_head
|
||||
)
|
||||
|
||||
column_idxs = _get_verts_column_indices(vertex_head)
|
||||
|
||||
# Case of no vertices
|
||||
if vertex_head.count == 0:
|
||||
verts = torch.zeros((0, 3), dtype=torch.float32)
|
||||
if color_idxs is None:
|
||||
return verts, None, None
|
||||
return verts, torch.zeros((0, 3), dtype=torch.float32), None
|
||||
if column_idxs.color_idxs is None:
|
||||
return _VertsData(verts=verts)
|
||||
return _VertsData(
|
||||
verts=verts, verts_colors=torch.zeros((0, 3), dtype=torch.float32)
|
||||
)
|
||||
|
||||
# Simple case where the only data is the vertices themselves
|
||||
if (
|
||||
@ -888,7 +918,7 @@ def _get_verts(
|
||||
and vertex[0].ndim == 2
|
||||
and vertex[0].shape[1] == 3
|
||||
):
|
||||
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
|
||||
return _VertsData(verts=_make_tensor(vertex[0], cols=3, dtype=torch.float32))
|
||||
|
||||
vertex_colors = None
|
||||
vertex_normals = None
|
||||
@ -896,14 +926,14 @@ def _get_verts(
|
||||
if len(vertex) == 1:
|
||||
# This is the case where the whole vertex element has one type,
|
||||
# so it was read as a single array and we can index straight into it.
|
||||
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
|
||||
if color_idxs is not None:
|
||||
vertex_colors = color_scale * torch.tensor(
|
||||
vertex[0][:, color_idxs], dtype=torch.float32
|
||||
verts = torch.tensor(vertex[0][:, column_idxs.point_idxs], dtype=torch.float32)
|
||||
if column_idxs.color_idxs is not None:
|
||||
vertex_colors = column_idxs.color_scale * torch.tensor(
|
||||
vertex[0][:, column_idxs.color_idxs], dtype=torch.float32
|
||||
)
|
||||
if normal_idxs is not None:
|
||||
if column_idxs.normal_idxs is not None:
|
||||
vertex_normals = torch.tensor(
|
||||
vertex[0][:, normal_idxs], dtype=torch.float32
|
||||
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
# The vertex element is heterogeneous. It was read as several arrays,
|
||||
@ -918,7 +948,7 @@ def _get_verts(
|
||||
]
|
||||
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
|
||||
for axis in range(3):
|
||||
partnum, col = prop_to_partnum_col[point_idxs[axis]]
|
||||
partnum, col = prop_to_partnum_col[column_idxs.point_idxs[axis]]
|
||||
verts.numpy()[:, axis] = vertex[partnum][:, col]
|
||||
# Note that in the previous line, we made the assignment
|
||||
# as numpy arrays by casting verts. If we took the (more
|
||||
@ -928,30 +958,49 @@ def _get_verts(
|
||||
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
|
||||
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
|
||||
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
|
||||
if color_idxs is not None:
|
||||
if column_idxs.color_idxs is not None:
|
||||
vertex_colors = torch.empty(
|
||||
size=(vertex_head.count, 3), dtype=torch.float32
|
||||
)
|
||||
for color in range(3):
|
||||
partnum, col = prop_to_partnum_col[color_idxs[color]]
|
||||
partnum, col = prop_to_partnum_col[column_idxs.color_idxs[color]]
|
||||
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
||||
vertex_colors *= color_scale
|
||||
if normal_idxs is not None:
|
||||
vertex_colors *= column_idxs.color_scale
|
||||
if column_idxs.normal_idxs is not None:
|
||||
vertex_normals = torch.empty(
|
||||
size=(vertex_head.count, 3), dtype=torch.float32
|
||||
)
|
||||
for axis in range(3):
|
||||
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
|
||||
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
|
||||
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
|
||||
|
||||
return verts, vertex_colors, vertex_normals
|
||||
return _VertsData(
|
||||
verts=verts,
|
||||
verts_colors=vertex_colors,
|
||||
verts_normals=vertex_normals,
|
||||
)
|
||||
|
||||
|
||||
def _load_ply(
|
||||
f, *, path_manager: PathManager
|
||||
) -> Tuple[
|
||||
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
|
||||
]:
|
||||
@dataclass(frozen=True)
|
||||
class _PlyData:
|
||||
"""
|
||||
Contains the data from a PLY file which has been read.
|
||||
Members:
|
||||
header: _PlyHeader of file metadata from the header
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
faces: None or LongTensor of vertex indices, shape (F, 3).
|
||||
verts_colors: None or FloatTensor of shape (V, 3).
|
||||
verts_normals: None or FloatTensor of shape (V, 3).
|
||||
"""
|
||||
|
||||
header: _PlyHeader
|
||||
verts: torch.Tensor
|
||||
faces: Optional[torch.Tensor]
|
||||
verts_colors: Optional[torch.Tensor]
|
||||
verts_normals: Optional[torch.Tensor]
|
||||
|
||||
|
||||
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
|
||||
"""
|
||||
Load the data from a .ply file.
|
||||
|
||||
@ -964,14 +1013,11 @@ def _load_ply(
|
||||
path_manager: PathManager for loading if f is a str.
|
||||
|
||||
Returns:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
faces: None or LongTensor of vertex indices, shape (F, 3).
|
||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
||||
vertex_normals: None or FloatTensor of shape (V, 3).
|
||||
_PlyData object
|
||||
"""
|
||||
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
||||
|
||||
verts, vertex_colors, vertex_normals = _get_verts(header, elements)
|
||||
verts_data = _get_verts(header, elements)
|
||||
|
||||
face = elements.get("face", None)
|
||||
if face is not None:
|
||||
@ -1007,9 +1053,9 @@ def _load_ply(
|
||||
faces = torch.tensor(face_list, dtype=torch.int64)
|
||||
|
||||
if faces is not None:
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
_check_faces_indices(faces, max_index=verts_data.verts.shape[0])
|
||||
|
||||
return verts, faces, vertex_colors, vertex_normals
|
||||
return _PlyData(**asdict(verts_data), faces=faces, header=header)
|
||||
|
||||
|
||||
def load_ply(
|
||||
@ -1064,11 +1110,12 @@ def load_ply(
|
||||
|
||||
if path_manager is None:
|
||||
path_manager = PathManager()
|
||||
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
|
||||
data = _load_ply(f, path_manager=path_manager)
|
||||
faces = data.faces
|
||||
if faces is None:
|
||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||
|
||||
return verts, faces
|
||||
return data.verts, faces
|
||||
|
||||
|
||||
def _write_ply_header(
|
||||
@ -1305,20 +1352,20 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces, verts_colors, verts_normals = _load_ply(
|
||||
f=path, path_manager=path_manager
|
||||
)
|
||||
data = _load_ply(f=path, path_manager=path_manager)
|
||||
faces = data.faces
|
||||
if faces is None:
|
||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||
|
||||
texture = None
|
||||
if include_textures and verts_colors is not None:
|
||||
texture = TexturesVertex([verts_colors.to(device)])
|
||||
if include_textures and data.verts_colors is not None:
|
||||
texture = TexturesVertex([data.verts_colors.to(device)])
|
||||
|
||||
if verts_normals is not None:
|
||||
verts_normals = [verts_normals]
|
||||
verts_normals = None
|
||||
if data.verts_normals is not None:
|
||||
verts_normals = [data.verts_normals.to(device)]
|
||||
mesh = Meshes(
|
||||
verts=[verts.to(device)],
|
||||
verts=[data.verts.to(device)],
|
||||
faces=[faces.to(device)],
|
||||
textures=texture,
|
||||
verts_normals=verts_normals,
|
||||
@ -1392,14 +1439,17 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
|
||||
verts = verts.to(device)
|
||||
if features is not None:
|
||||
features = [features.to(device)]
|
||||
if normals is not None:
|
||||
normals = [normals.to(device)]
|
||||
data = _load_ply(f=path, path_manager=path_manager)
|
||||
features = None
|
||||
if data.verts_colors is not None:
|
||||
features = [data.verts_colors.to(device)]
|
||||
normals = None
|
||||
if data.verts_normals is not None:
|
||||
normals = [data.verts_normals.to(device)]
|
||||
|
||||
pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
|
||||
pointcloud = Pointclouds(
|
||||
points=[data.verts.to(device)], features=features, normals=normals
|
||||
)
|
||||
return pointcloud
|
||||
|
||||
def save(
|
||||
|
Loading…
x
Reference in New Issue
Block a user