diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index a98ab168..37c86517 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -14,8 +14,9 @@ import struct import sys import warnings from collections import namedtuple +from dataclasses import asdict, dataclass from io import BytesIO, TextIOBase -from typing import List, Optional, Tuple, cast +from typing import List, Optional, Tuple import numpy as np import torch @@ -137,6 +138,7 @@ class _PlyHeader: self.ascii: (bool) Whether in ascii format self.big_endian: (bool) (if not ascii) whether big endian self.obj_info: (List[str]) arbitrary extra data + self.comments: (List[str]) comments Args: f: file-like object. @@ -145,7 +147,8 @@ class _PlyHeader: raise ValueError("Invalid file header.") seen_format = False self.elements: List[_PlyElementType] = [] - self.obj_info = [] + self.comments: List[str] = [] + self.obj_info: List[str] = [] while True: line = f.readline() if isinstance(line, bytes): @@ -176,6 +179,9 @@ class _PlyHeader: continue if line.startswith("format"): raise ValueError("Invalid format line.") + if line.startswith("comment "): + self.comments.append(line[8:]) + continue if line.startswith("comment") or len(line) == 0: continue if line.startswith("element"): @@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]: return header, elements +@dataclass(frozen=True) +class _VertsColumnIndices: + """ + Contains the relevant layout of the verts section of file being read. + Members + point_idxs: List[int] of 3 point columns. + color_idxs: List[int] of 3 color columns if they are present, + otherwise None. + color_scale: value to scale colors by. + normal_idxs: List[int] of 3 normals columns if they are present, + otherwise None. + """ + + point_idxs: List[int] + color_idxs: Optional[List[int]] + color_scale: float + normal_idxs: Optional[List[int]] + + def _get_verts_column_indices( vertex_head: _PlyElementType, -) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]: +) -> _VertsColumnIndices: """ Get the columns of verts, verts_colors, and verts_normals in the vertex element of a parsed ply file, together with a color scale factor. @@ -809,12 +834,7 @@ def _get_verts_column_indices( vertex_head: as returned from load_ply_raw. Returns: - point_idxs: List[int] of 3 point columns. - color_idxs: List[int] of 3 color columns if they are present, - otherwise None. - color_scale: value to scale colors by. - normal_idxs: List[int] of 3 normals columns if they are present, - otherwise None. + _VertsColumnIndices object """ point_idxs: List[Optional[int]] = [None, None, None] color_idxs: List[Optional[int]] = [None, None, None] @@ -839,19 +859,30 @@ def _get_verts_column_indices( for idx in color_idxs ): color_scale = 1.0 / 255 - return ( - point_idxs, - # pyre-fixme[22]: The cast is redundant. - None if None in color_idxs else cast(List[int], color_idxs), - color_scale, - # pyre-fixme[22]: The cast is redundant. - None if None in normal_idxs else cast(List[int], normal_idxs), + return _VertsColumnIndices( + point_idxs=point_idxs, + color_idxs=None if None in color_idxs else color_idxs, + color_scale=color_scale, + normal_idxs=None if None in normal_idxs else normal_idxs, ) -def _get_verts( - header: _PlyHeader, elements: dict -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: +@dataclass(frozen=True) +class _VertsData: + """ + Contains the data of the verts section of file being read. + Members: + verts: FloatTensor of shape (V, 3). + verts_colors: None or FloatTensor of shape (V, 3). + verts_normals: None or FloatTensor of shape (V, 3). + """ + + verts: torch.Tensor + verts_colors: Optional[torch.Tensor] = None + verts_normals: Optional[torch.Tensor] = None + + +def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: """ Get the vertex locations, colors and normals from a parsed ply file. @@ -859,9 +890,7 @@ def _get_verts( header, elements: as returned from load_ply_raw. Returns: - verts: FloatTensor of shape (V, 3). - vertex_colors: None or FloatTensor of shape (V, 3). - vertex_normals: None or FloatTensor of shape (V, 3). + _VertsData object """ vertex = elements.get("vertex", None) @@ -870,16 +899,17 @@ def _get_verts( if not isinstance(vertex, list): raise ValueError("Invalid vertices in file.") vertex_head = next(head for head in header.elements if head.name == "vertex") - point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices( - vertex_head - ) + + column_idxs = _get_verts_column_indices(vertex_head) # Case of no vertices if vertex_head.count == 0: verts = torch.zeros((0, 3), dtype=torch.float32) - if color_idxs is None: - return verts, None, None - return verts, torch.zeros((0, 3), dtype=torch.float32), None + if column_idxs.color_idxs is None: + return _VertsData(verts=verts) + return _VertsData( + verts=verts, verts_colors=torch.zeros((0, 3), dtype=torch.float32) + ) # Simple case where the only data is the vertices themselves if ( @@ -888,7 +918,7 @@ def _get_verts( and vertex[0].ndim == 2 and vertex[0].shape[1] == 3 ): - return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None + return _VertsData(verts=_make_tensor(vertex[0], cols=3, dtype=torch.float32)) vertex_colors = None vertex_normals = None @@ -896,14 +926,14 @@ def _get_verts( if len(vertex) == 1: # This is the case where the whole vertex element has one type, # so it was read as a single array and we can index straight into it. - verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32) - if color_idxs is not None: - vertex_colors = color_scale * torch.tensor( - vertex[0][:, color_idxs], dtype=torch.float32 + verts = torch.tensor(vertex[0][:, column_idxs.point_idxs], dtype=torch.float32) + if column_idxs.color_idxs is not None: + vertex_colors = column_idxs.color_scale * torch.tensor( + vertex[0][:, column_idxs.color_idxs], dtype=torch.float32 ) - if normal_idxs is not None: + if column_idxs.normal_idxs is not None: vertex_normals = torch.tensor( - vertex[0][:, normal_idxs], dtype=torch.float32 + vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32 ) else: # The vertex element is heterogeneous. It was read as several arrays, @@ -918,7 +948,7 @@ def _get_verts( ] verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32) for axis in range(3): - partnum, col = prop_to_partnum_col[point_idxs[axis]] + partnum, col = prop_to_partnum_col[column_idxs.point_idxs[axis]] verts.numpy()[:, axis] = vertex[partnum][:, col] # Note that in the previous line, we made the assignment # as numpy arrays by casting verts. If we took the (more @@ -928,30 +958,49 @@ def _get_verts( # if not vertex[partnum].flags["C_CONTIGUOUS"]: # vertex[partnum] = np.ascontiguousarray(vertex[partnum]) # verts[:, axis] = torch.tensor((vertex[partnum][:, col])) - if color_idxs is not None: + if column_idxs.color_idxs is not None: vertex_colors = torch.empty( size=(vertex_head.count, 3), dtype=torch.float32 ) for color in range(3): - partnum, col = prop_to_partnum_col[color_idxs[color]] + partnum, col = prop_to_partnum_col[column_idxs.color_idxs[color]] vertex_colors.numpy()[:, color] = vertex[partnum][:, col] - vertex_colors *= color_scale - if normal_idxs is not None: + vertex_colors *= column_idxs.color_scale + if column_idxs.normal_idxs is not None: vertex_normals = torch.empty( size=(vertex_head.count, 3), dtype=torch.float32 ) for axis in range(3): - partnum, col = prop_to_partnum_col[normal_idxs[axis]] + partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]] vertex_normals.numpy()[:, axis] = vertex[partnum][:, col] - return verts, vertex_colors, vertex_normals + return _VertsData( + verts=verts, + verts_colors=vertex_colors, + verts_normals=vertex_normals, + ) -def _load_ply( - f, *, path_manager: PathManager -) -> Tuple[ - torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor] -]: +@dataclass(frozen=True) +class _PlyData: + """ + Contains the data from a PLY file which has been read. + Members: + header: _PlyHeader of file metadata from the header + verts: FloatTensor of shape (V, 3). + faces: None or LongTensor of vertex indices, shape (F, 3). + verts_colors: None or FloatTensor of shape (V, 3). + verts_normals: None or FloatTensor of shape (V, 3). + """ + + header: _PlyHeader + verts: torch.Tensor + faces: Optional[torch.Tensor] + verts_colors: Optional[torch.Tensor] + verts_normals: Optional[torch.Tensor] + + +def _load_ply(f, *, path_manager: PathManager) -> _PlyData: """ Load the data from a .ply file. @@ -964,14 +1013,11 @@ def _load_ply( path_manager: PathManager for loading if f is a str. Returns: - verts: FloatTensor of shape (V, 3). - faces: None or LongTensor of vertex indices, shape (F, 3). - vertex_colors: None or FloatTensor of shape (V, 3). - vertex_normals: None or FloatTensor of shape (V, 3). + _PlyData object """ header, elements = _load_ply_raw(f, path_manager=path_manager) - verts, vertex_colors, vertex_normals = _get_verts(header, elements) + verts_data = _get_verts(header, elements) face = elements.get("face", None) if face is not None: @@ -1007,9 +1053,9 @@ def _load_ply( faces = torch.tensor(face_list, dtype=torch.int64) if faces is not None: - _check_faces_indices(faces, max_index=verts.shape[0]) + _check_faces_indices(faces, max_index=verts_data.verts.shape[0]) - return verts, faces, vertex_colors, vertex_normals + return _PlyData(**asdict(verts_data), faces=faces, header=header) def load_ply( @@ -1064,11 +1110,12 @@ def load_ply( if path_manager is None: path_manager = PathManager() - verts, faces, _, _ = _load_ply(f, path_manager=path_manager) + data = _load_ply(f, path_manager=path_manager) + faces = data.faces if faces is None: faces = torch.zeros(0, 3, dtype=torch.int64) - return verts, faces + return data.verts, faces def _write_ply_header( @@ -1305,20 +1352,20 @@ class MeshPlyFormat(MeshFormatInterpreter): if not endswith(path, self.known_suffixes): return None - verts, faces, verts_colors, verts_normals = _load_ply( - f=path, path_manager=path_manager - ) + data = _load_ply(f=path, path_manager=path_manager) + faces = data.faces if faces is None: faces = torch.zeros(0, 3, dtype=torch.int64) texture = None - if include_textures and verts_colors is not None: - texture = TexturesVertex([verts_colors.to(device)]) + if include_textures and data.verts_colors is not None: + texture = TexturesVertex([data.verts_colors.to(device)]) - if verts_normals is not None: - verts_normals = [verts_normals] + verts_normals = None + if data.verts_normals is not None: + verts_normals = [data.verts_normals.to(device)] mesh = Meshes( - verts=[verts.to(device)], + verts=[data.verts.to(device)], faces=[faces.to(device)], textures=texture, verts_normals=verts_normals, @@ -1392,14 +1439,17 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): if not endswith(path, self.known_suffixes): return None - verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager) - verts = verts.to(device) - if features is not None: - features = [features.to(device)] - if normals is not None: - normals = [normals.to(device)] + data = _load_ply(f=path, path_manager=path_manager) + features = None + if data.verts_colors is not None: + features = [data.verts_colors.to(device)] + normals = None + if data.verts_normals is not None: + normals = [data.verts_normals.to(device)] - pointcloud = Pointclouds(points=[verts], features=features, normals=normals) + pointcloud = Pointclouds( + points=[data.verts.to(device)], features=features, normals=normals + ) return pointcloud def save(