mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Use dataclasses inside ply_io.
Summary: Refactor ply_io to make it easier to add new features. Mostly taken from the starting code I attached to https://github.com/facebookresearch/pytorch3d/issues/904. Reviewed By: patricklabatut Differential Revision: D34375978 fbshipit-source-id: ec017d31f07c6f71ba6d97a0623bb10be1e81212
This commit is contained in:
parent
feb5d36394
commit
967a099231
@ -14,8 +14,9 @@ import struct
|
|||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
from io import BytesIO, TextIOBase
|
from io import BytesIO, TextIOBase
|
||||||
from typing import List, Optional, Tuple, cast
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -137,6 +138,7 @@ class _PlyHeader:
|
|||||||
self.ascii: (bool) Whether in ascii format
|
self.ascii: (bool) Whether in ascii format
|
||||||
self.big_endian: (bool) (if not ascii) whether big endian
|
self.big_endian: (bool) (if not ascii) whether big endian
|
||||||
self.obj_info: (List[str]) arbitrary extra data
|
self.obj_info: (List[str]) arbitrary extra data
|
||||||
|
self.comments: (List[str]) comments
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
f: file-like object.
|
f: file-like object.
|
||||||
@ -145,7 +147,8 @@ class _PlyHeader:
|
|||||||
raise ValueError("Invalid file header.")
|
raise ValueError("Invalid file header.")
|
||||||
seen_format = False
|
seen_format = False
|
||||||
self.elements: List[_PlyElementType] = []
|
self.elements: List[_PlyElementType] = []
|
||||||
self.obj_info = []
|
self.comments: List[str] = []
|
||||||
|
self.obj_info: List[str] = []
|
||||||
while True:
|
while True:
|
||||||
line = f.readline()
|
line = f.readline()
|
||||||
if isinstance(line, bytes):
|
if isinstance(line, bytes):
|
||||||
@ -176,6 +179,9 @@ class _PlyHeader:
|
|||||||
continue
|
continue
|
||||||
if line.startswith("format"):
|
if line.startswith("format"):
|
||||||
raise ValueError("Invalid format line.")
|
raise ValueError("Invalid format line.")
|
||||||
|
if line.startswith("comment "):
|
||||||
|
self.comments.append(line[8:])
|
||||||
|
continue
|
||||||
if line.startswith("comment") or len(line) == 0:
|
if line.startswith("comment") or len(line) == 0:
|
||||||
continue
|
continue
|
||||||
if line.startswith("element"):
|
if line.startswith("element"):
|
||||||
@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
|
|||||||
return header, elements
|
return header, elements
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _VertsColumnIndices:
|
||||||
|
"""
|
||||||
|
Contains the relevant layout of the verts section of file being read.
|
||||||
|
Members
|
||||||
|
point_idxs: List[int] of 3 point columns.
|
||||||
|
color_idxs: List[int] of 3 color columns if they are present,
|
||||||
|
otherwise None.
|
||||||
|
color_scale: value to scale colors by.
|
||||||
|
normal_idxs: List[int] of 3 normals columns if they are present,
|
||||||
|
otherwise None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
point_idxs: List[int]
|
||||||
|
color_idxs: Optional[List[int]]
|
||||||
|
color_scale: float
|
||||||
|
normal_idxs: Optional[List[int]]
|
||||||
|
|
||||||
|
|
||||||
def _get_verts_column_indices(
|
def _get_verts_column_indices(
|
||||||
vertex_head: _PlyElementType,
|
vertex_head: _PlyElementType,
|
||||||
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
|
) -> _VertsColumnIndices:
|
||||||
"""
|
"""
|
||||||
Get the columns of verts, verts_colors, and verts_normals in the vertex
|
Get the columns of verts, verts_colors, and verts_normals in the vertex
|
||||||
element of a parsed ply file, together with a color scale factor.
|
element of a parsed ply file, together with a color scale factor.
|
||||||
@ -809,12 +834,7 @@ def _get_verts_column_indices(
|
|||||||
vertex_head: as returned from load_ply_raw.
|
vertex_head: as returned from load_ply_raw.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
point_idxs: List[int] of 3 point columns.
|
_VertsColumnIndices object
|
||||||
color_idxs: List[int] of 3 color columns if they are present,
|
|
||||||
otherwise None.
|
|
||||||
color_scale: value to scale colors by.
|
|
||||||
normal_idxs: List[int] of 3 normals columns if they are present,
|
|
||||||
otherwise None.
|
|
||||||
"""
|
"""
|
||||||
point_idxs: List[Optional[int]] = [None, None, None]
|
point_idxs: List[Optional[int]] = [None, None, None]
|
||||||
color_idxs: List[Optional[int]] = [None, None, None]
|
color_idxs: List[Optional[int]] = [None, None, None]
|
||||||
@ -839,19 +859,30 @@ def _get_verts_column_indices(
|
|||||||
for idx in color_idxs
|
for idx in color_idxs
|
||||||
):
|
):
|
||||||
color_scale = 1.0 / 255
|
color_scale = 1.0 / 255
|
||||||
return (
|
return _VertsColumnIndices(
|
||||||
point_idxs,
|
point_idxs=point_idxs,
|
||||||
# pyre-fixme[22]: The cast is redundant.
|
color_idxs=None if None in color_idxs else color_idxs,
|
||||||
None if None in color_idxs else cast(List[int], color_idxs),
|
color_scale=color_scale,
|
||||||
color_scale,
|
normal_idxs=None if None in normal_idxs else normal_idxs,
|
||||||
# pyre-fixme[22]: The cast is redundant.
|
|
||||||
None if None in normal_idxs else cast(List[int], normal_idxs),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_verts(
|
@dataclass(frozen=True)
|
||||||
header: _PlyHeader, elements: dict
|
class _VertsData:
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
"""
|
||||||
|
Contains the data of the verts section of file being read.
|
||||||
|
Members:
|
||||||
|
verts: FloatTensor of shape (V, 3).
|
||||||
|
verts_colors: None or FloatTensor of shape (V, 3).
|
||||||
|
verts_normals: None or FloatTensor of shape (V, 3).
|
||||||
|
"""
|
||||||
|
|
||||||
|
verts: torch.Tensor
|
||||||
|
verts_colors: Optional[torch.Tensor] = None
|
||||||
|
verts_normals: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
|
||||||
"""
|
"""
|
||||||
Get the vertex locations, colors and normals from a parsed ply file.
|
Get the vertex locations, colors and normals from a parsed ply file.
|
||||||
|
|
||||||
@ -859,9 +890,7 @@ def _get_verts(
|
|||||||
header, elements: as returned from load_ply_raw.
|
header, elements: as returned from load_ply_raw.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
verts: FloatTensor of shape (V, 3).
|
_VertsData object
|
||||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
|
||||||
vertex_normals: None or FloatTensor of shape (V, 3).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vertex = elements.get("vertex", None)
|
vertex = elements.get("vertex", None)
|
||||||
@ -870,16 +899,17 @@ def _get_verts(
|
|||||||
if not isinstance(vertex, list):
|
if not isinstance(vertex, list):
|
||||||
raise ValueError("Invalid vertices in file.")
|
raise ValueError("Invalid vertices in file.")
|
||||||
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
||||||
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
|
|
||||||
vertex_head
|
column_idxs = _get_verts_column_indices(vertex_head)
|
||||||
)
|
|
||||||
|
|
||||||
# Case of no vertices
|
# Case of no vertices
|
||||||
if vertex_head.count == 0:
|
if vertex_head.count == 0:
|
||||||
verts = torch.zeros((0, 3), dtype=torch.float32)
|
verts = torch.zeros((0, 3), dtype=torch.float32)
|
||||||
if color_idxs is None:
|
if column_idxs.color_idxs is None:
|
||||||
return verts, None, None
|
return _VertsData(verts=verts)
|
||||||
return verts, torch.zeros((0, 3), dtype=torch.float32), None
|
return _VertsData(
|
||||||
|
verts=verts, verts_colors=torch.zeros((0, 3), dtype=torch.float32)
|
||||||
|
)
|
||||||
|
|
||||||
# Simple case where the only data is the vertices themselves
|
# Simple case where the only data is the vertices themselves
|
||||||
if (
|
if (
|
||||||
@ -888,7 +918,7 @@ def _get_verts(
|
|||||||
and vertex[0].ndim == 2
|
and vertex[0].ndim == 2
|
||||||
and vertex[0].shape[1] == 3
|
and vertex[0].shape[1] == 3
|
||||||
):
|
):
|
||||||
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
|
return _VertsData(verts=_make_tensor(vertex[0], cols=3, dtype=torch.float32))
|
||||||
|
|
||||||
vertex_colors = None
|
vertex_colors = None
|
||||||
vertex_normals = None
|
vertex_normals = None
|
||||||
@ -896,14 +926,14 @@ def _get_verts(
|
|||||||
if len(vertex) == 1:
|
if len(vertex) == 1:
|
||||||
# This is the case where the whole vertex element has one type,
|
# This is the case where the whole vertex element has one type,
|
||||||
# so it was read as a single array and we can index straight into it.
|
# so it was read as a single array and we can index straight into it.
|
||||||
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
|
verts = torch.tensor(vertex[0][:, column_idxs.point_idxs], dtype=torch.float32)
|
||||||
if color_idxs is not None:
|
if column_idxs.color_idxs is not None:
|
||||||
vertex_colors = color_scale * torch.tensor(
|
vertex_colors = column_idxs.color_scale * torch.tensor(
|
||||||
vertex[0][:, color_idxs], dtype=torch.float32
|
vertex[0][:, column_idxs.color_idxs], dtype=torch.float32
|
||||||
)
|
)
|
||||||
if normal_idxs is not None:
|
if column_idxs.normal_idxs is not None:
|
||||||
vertex_normals = torch.tensor(
|
vertex_normals = torch.tensor(
|
||||||
vertex[0][:, normal_idxs], dtype=torch.float32
|
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# The vertex element is heterogeneous. It was read as several arrays,
|
# The vertex element is heterogeneous. It was read as several arrays,
|
||||||
@ -918,7 +948,7 @@ def _get_verts(
|
|||||||
]
|
]
|
||||||
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
|
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
|
||||||
for axis in range(3):
|
for axis in range(3):
|
||||||
partnum, col = prop_to_partnum_col[point_idxs[axis]]
|
partnum, col = prop_to_partnum_col[column_idxs.point_idxs[axis]]
|
||||||
verts.numpy()[:, axis] = vertex[partnum][:, col]
|
verts.numpy()[:, axis] = vertex[partnum][:, col]
|
||||||
# Note that in the previous line, we made the assignment
|
# Note that in the previous line, we made the assignment
|
||||||
# as numpy arrays by casting verts. If we took the (more
|
# as numpy arrays by casting verts. If we took the (more
|
||||||
@ -928,30 +958,49 @@ def _get_verts(
|
|||||||
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
|
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
|
||||||
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
|
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
|
||||||
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
|
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
|
||||||
if color_idxs is not None:
|
if column_idxs.color_idxs is not None:
|
||||||
vertex_colors = torch.empty(
|
vertex_colors = torch.empty(
|
||||||
size=(vertex_head.count, 3), dtype=torch.float32
|
size=(vertex_head.count, 3), dtype=torch.float32
|
||||||
)
|
)
|
||||||
for color in range(3):
|
for color in range(3):
|
||||||
partnum, col = prop_to_partnum_col[color_idxs[color]]
|
partnum, col = prop_to_partnum_col[column_idxs.color_idxs[color]]
|
||||||
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
||||||
vertex_colors *= color_scale
|
vertex_colors *= column_idxs.color_scale
|
||||||
if normal_idxs is not None:
|
if column_idxs.normal_idxs is not None:
|
||||||
vertex_normals = torch.empty(
|
vertex_normals = torch.empty(
|
||||||
size=(vertex_head.count, 3), dtype=torch.float32
|
size=(vertex_head.count, 3), dtype=torch.float32
|
||||||
)
|
)
|
||||||
for axis in range(3):
|
for axis in range(3):
|
||||||
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
|
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
|
||||||
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
|
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
|
||||||
|
|
||||||
return verts, vertex_colors, vertex_normals
|
return _VertsData(
|
||||||
|
verts=verts,
|
||||||
|
verts_colors=vertex_colors,
|
||||||
|
verts_normals=vertex_normals,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_ply(
|
@dataclass(frozen=True)
|
||||||
f, *, path_manager: PathManager
|
class _PlyData:
|
||||||
) -> Tuple[
|
"""
|
||||||
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
|
Contains the data from a PLY file which has been read.
|
||||||
]:
|
Members:
|
||||||
|
header: _PlyHeader of file metadata from the header
|
||||||
|
verts: FloatTensor of shape (V, 3).
|
||||||
|
faces: None or LongTensor of vertex indices, shape (F, 3).
|
||||||
|
verts_colors: None or FloatTensor of shape (V, 3).
|
||||||
|
verts_normals: None or FloatTensor of shape (V, 3).
|
||||||
|
"""
|
||||||
|
|
||||||
|
header: _PlyHeader
|
||||||
|
verts: torch.Tensor
|
||||||
|
faces: Optional[torch.Tensor]
|
||||||
|
verts_colors: Optional[torch.Tensor]
|
||||||
|
verts_normals: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
|
||||||
"""
|
"""
|
||||||
Load the data from a .ply file.
|
Load the data from a .ply file.
|
||||||
|
|
||||||
@ -964,14 +1013,11 @@ def _load_ply(
|
|||||||
path_manager: PathManager for loading if f is a str.
|
path_manager: PathManager for loading if f is a str.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
verts: FloatTensor of shape (V, 3).
|
_PlyData object
|
||||||
faces: None or LongTensor of vertex indices, shape (F, 3).
|
|
||||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
|
||||||
vertex_normals: None or FloatTensor of shape (V, 3).
|
|
||||||
"""
|
"""
|
||||||
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
||||||
|
|
||||||
verts, vertex_colors, vertex_normals = _get_verts(header, elements)
|
verts_data = _get_verts(header, elements)
|
||||||
|
|
||||||
face = elements.get("face", None)
|
face = elements.get("face", None)
|
||||||
if face is not None:
|
if face is not None:
|
||||||
@ -1007,9 +1053,9 @@ def _load_ply(
|
|||||||
faces = torch.tensor(face_list, dtype=torch.int64)
|
faces = torch.tensor(face_list, dtype=torch.int64)
|
||||||
|
|
||||||
if faces is not None:
|
if faces is not None:
|
||||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
_check_faces_indices(faces, max_index=verts_data.verts.shape[0])
|
||||||
|
|
||||||
return verts, faces, vertex_colors, vertex_normals
|
return _PlyData(**asdict(verts_data), faces=faces, header=header)
|
||||||
|
|
||||||
|
|
||||||
def load_ply(
|
def load_ply(
|
||||||
@ -1064,11 +1110,12 @@ def load_ply(
|
|||||||
|
|
||||||
if path_manager is None:
|
if path_manager is None:
|
||||||
path_manager = PathManager()
|
path_manager = PathManager()
|
||||||
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
|
data = _load_ply(f, path_manager=path_manager)
|
||||||
|
faces = data.faces
|
||||||
if faces is None:
|
if faces is None:
|
||||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||||
|
|
||||||
return verts, faces
|
return data.verts, faces
|
||||||
|
|
||||||
|
|
||||||
def _write_ply_header(
|
def _write_ply_header(
|
||||||
@ -1305,20 +1352,20 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
|||||||
if not endswith(path, self.known_suffixes):
|
if not endswith(path, self.known_suffixes):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
verts, faces, verts_colors, verts_normals = _load_ply(
|
data = _load_ply(f=path, path_manager=path_manager)
|
||||||
f=path, path_manager=path_manager
|
faces = data.faces
|
||||||
)
|
|
||||||
if faces is None:
|
if faces is None:
|
||||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||||
|
|
||||||
texture = None
|
texture = None
|
||||||
if include_textures and verts_colors is not None:
|
if include_textures and data.verts_colors is not None:
|
||||||
texture = TexturesVertex([verts_colors.to(device)])
|
texture = TexturesVertex([data.verts_colors.to(device)])
|
||||||
|
|
||||||
if verts_normals is not None:
|
verts_normals = None
|
||||||
verts_normals = [verts_normals]
|
if data.verts_normals is not None:
|
||||||
|
verts_normals = [data.verts_normals.to(device)]
|
||||||
mesh = Meshes(
|
mesh = Meshes(
|
||||||
verts=[verts.to(device)],
|
verts=[data.verts.to(device)],
|
||||||
faces=[faces.to(device)],
|
faces=[faces.to(device)],
|
||||||
textures=texture,
|
textures=texture,
|
||||||
verts_normals=verts_normals,
|
verts_normals=verts_normals,
|
||||||
@ -1392,14 +1439,17 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
|||||||
if not endswith(path, self.known_suffixes):
|
if not endswith(path, self.known_suffixes):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
|
data = _load_ply(f=path, path_manager=path_manager)
|
||||||
verts = verts.to(device)
|
features = None
|
||||||
if features is not None:
|
if data.verts_colors is not None:
|
||||||
features = [features.to(device)]
|
features = [data.verts_colors.to(device)]
|
||||||
if normals is not None:
|
normals = None
|
||||||
normals = [normals.to(device)]
|
if data.verts_normals is not None:
|
||||||
|
normals = [data.verts_normals.to(device)]
|
||||||
|
|
||||||
pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
|
pointcloud = Pointclouds(
|
||||||
|
points=[data.verts.to(device)], features=features, normals=normals
|
||||||
|
)
|
||||||
return pointcloud
|
return pointcloud
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user