diff --git a/docs/notes/io.md b/docs/notes/io.md new file mode 100644 index 00000000..a6c17793 --- /dev/null +++ b/docs/notes/io.md @@ -0,0 +1,24 @@ +--- +hide_title: true +sidebar_label: File IO +--- + +# File IO +There is a flexible interface for loading and saving point clouds and meshes from different formats. + +The main usage is via the `pytorch3d.io.IO` object, and its methods +`load_mesh`, `save_mesh`, `load_point_cloud` and `save_point_cloud`. + +For example, to load a mesh you might do +``` +from pytorch3d.io import IO + +device=torch.device("cuda:0") +mesh = IO().load_mesh("mymesh.ply", device=device) +``` + +and to save a pointcloud you might do +``` +pcl = Pointclouds(...) +IO().save_point_cloud(pcl, "output_poincloud.obj") +``` diff --git a/pytorch3d/io/__init__.py b/pytorch3d/io/__init__.py index 388fd126..b3fe212b 100644 --- a/pytorch3d/io/__init__.py +++ b/pytorch3d/io/__init__.py @@ -2,6 +2,7 @@ from .obj_io import load_obj, load_objs_as_meshes, save_obj +from .pluggable import IO from .ply_io import load_ply, save_ply diff --git a/pytorch3d/io/pluggable.py b/pytorch3d/io/pluggable.py new file mode 100644 index 00000000..5c03dd4a --- /dev/null +++ b/pytorch3d/io/pluggable.py @@ -0,0 +1,208 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +from collections import deque +from pathlib import Path +from typing import Deque, Optional, Union + +from iopath.common.file_io import PathManager +from pytorch3d.structures import Meshes, Pointclouds + +from .pluggable_formats import MeshFormatInterpreter, PointcloudFormatInterpreter + + +""" +This module has the master functions for loading and saving data. + +The main usage is via the IO object, and its methods +`load_mesh`, `save_mesh`, `load_pointcloud` and `save_pointcloud`. + +For example, to load a mesh you might do +``` +from pytorch3d.io import IO + +mesh = IO().load_mesh("mymesh.obj") +``` + +and to save a point cloud you might do + +``` +pcl = Pointclouds(...) +IO().save_pointcloud(pcl, "output_poincloud.obj") +``` + +""" + + +class IO: + """ + This class is the interface to flexible loading and saving of meshes and point clouds. + + In simple cases the user will just initialise an instance of this class as `IO()` + and then use its load and save functions. The arguments of the initializer are not + usually needed. + + The user can add their own formats for saving and loading by passing their own objects + to the register_* functions. + + Args: + include_default_formats: If False, the built-in file formats will not be available. + Then only user-registered formats can be used. + path_manager: Used to customise how paths given as strings are interpreted. + """ + + def __init__( + self, + include_default_formats: bool = True, + path_manager: Optional[PathManager] = None, + ): + if path_manager is None: + self.path_manager = PathManager() + else: + self.path_manager = path_manager + + self.mesh_interpreters: Deque[MeshFormatInterpreter] = deque() + self.pointcloud_interpreters: Deque[PointcloudFormatInterpreter] = deque() + + if include_default_formats: + self.register_default_formats() + + def register_default_formats(self) -> None: + # This will be populated in later diffs + pass + + def register_meshes_format(self, interpreter: MeshFormatInterpreter) -> None: + """ + Register a new interpreter for a new mesh file format. + + Args: + interpreter: the new interpreter to use, which must be an instance + of a class which inherits MeshFormatInterpreter. + """ + self.mesh_interpreters.appendleft(interpreter) + + def register_pointcloud_format( + self, interpreter: PointcloudFormatInterpreter + ) -> None: + """ + Register a new interpreter for a new point cloud file format. + + Args: + interpreter: the new interpreter to use, which must be an instance + of a class which inherits PointcloudFormatInterpreter. + """ + self.pointcloud_interpreters.appendleft(interpreter) + + def load_mesh( + self, + path: Union[str, Path], + include_textures: bool = True, + device="cpu", + **kwargs, + ) -> Meshes: + """ + Attempt to load a mesh from the given file, using a registered format. + Materials are not returned. If you have a .obj file with materials + you might want to load them with the load_obj function instead. + + Args: + path: file to read + include_textures: whether to try to load texture information + device: device on which to leave the data. + + Returns: + new Meshes object containing one mesh. + """ + for mesh_interpreter in self.mesh_interpreters: + mesh = mesh_interpreter.read( + path, + include_textures=include_textures, + path_manager=self.path_manager, + device=device, + **kwargs, + ) + if mesh is not None: + return mesh + + raise ValueError(f"No mesh interpreter found to read {path}.") + + def save_mesh( + self, + data: Meshes, + path: Union[str, Path], + binary: Optional[bool] = None, + include_textures: bool = True, + **kwargs, + ) -> None: + """ + Attempt to save a mesh to the given file, using a registered format. + + Args: + data: a 1-element Meshes + path: file to write + binary: If there is a choice, whether to save in a binary format. + include_textures: If textures are present, whether to try to save + them. + """ + if len(data) != 1: + raise ValueError("Can only save a single mesh.") + + for mesh_interpreter in self.mesh_interpreters: + success = mesh_interpreter.save( + data, path, path_manager=self.path_manager, binary=binary, **kwargs + ) + if success: + return + + raise ValueError(f"No mesh interpreter found to write to {path}.") + + def load_pointcloud( + self, path: Union[str, Path], device="cpu", **kwargs + ) -> Pointclouds: + """ + Attempt to load a point cloud from the given file, using a registered format. + + Args: + path: file to read + device: torch.device on which to load the data. + + Returns: + new Pointclouds object containing one mesh. + """ + for pointcloud_interpreter in self.pointcloud_interpreters: + pointcloud = pointcloud_interpreter.read( + path, path_manager=self.path_manager, device=device, **kwargs + ) + if pointcloud is not None: + return pointcloud + + raise ValueError(f"No point cloud interpreter found to read {path}.") + + def save_pointcloud( + self, + data: Pointclouds, + path: Union[str, Path], + binary: Optional[bool] = None, + **kwargs, + ) -> None: + """ + Attempt to save a point cloud to the given file, using a registered format. + + Args: + data: a 1-element Pointclouds + path: file to write + binary: If there is a choice, whether to save in a binary format. + """ + if len(data) != 1: + raise ValueError("Can only save a single point cloud.") + + for pointcloud_interpreter in self.pointcloud_interpreters: + success = pointcloud_interpreter.save( + data, path, path_manager=self.path_manager, binary=binary, **kwargs + ) + if success: + return + + raise ValueError(f"No point cloud interpreter found to write to {path}.") diff --git a/pytorch3d/io/pluggable_formats.py b/pytorch3d/io/pluggable_formats.py new file mode 100644 index 00000000..968ac7c6 --- /dev/null +++ b/pytorch3d/io/pluggable_formats.py @@ -0,0 +1,136 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +from pathlib import Path +from typing import Optional, Tuple, Union + +from iopath.common.file_io import PathManager +from pytorch3d.structures import Meshes, Pointclouds + + +""" +This module has the base classes which must be extended to define +an interpreter for loading and saving data in a particular format. +These can be registered on an IO object so that they can be used in +its load_* and save_* functions. +""" + + +def endswith(path, suffixes: Tuple[str, ...]) -> bool: + """ + Returns whether the path ends with one of the given suffixes. + If `path` is not actually a path, returns True. This is useful + for allowing interpreters to bypass inappropriate paths, but + always accepting streams. + """ + if isinstance(path, Path): + return path.suffix.lower() in suffixes + if isinstance(path, str): + return path.lower().endswith(suffixes) + return True + + +class MeshFormatInterpreter: + """ + This is a base class for an interpreter which can read or write + a mesh in a particular format. + """ + + def read( + self, + path: Union[str, Path], + include_textures: bool, + device, + path_manager: PathManager, + **kwargs, + ) -> Optional[Meshes]: + """ + Read the data from the specified file and return it as + a Meshes object. + + Args: + path: path to load. + include_textures: whether to try to load texture information. + device: torch.device to load data on to. + path_manager: PathManager to interpret the path. + + Returns: + None if self is not the appropriate object to interpret the given + path. + Otherwise, the read Meshes object. + """ + raise NotImplementedError() + + def save( + self, + data: Meshes, + path: Union[str, Path], + path_manager: PathManager, + binary: Optional[bool], + **kwargs, + ) -> bool: + """ + Save the given Meshes object to the given path. + + Args: + data: mesh to save + path: path to save to, which may be overwritten. + path_manager: PathManager to interpret the path. + binary: If there is a choice, whether to save in a binary format. + + Returns: + False: if self is not the appropriate object to write to the given path. + True: on success. + """ + raise NotImplementedError() + + +class PointcloudFormatInterpreter: + """ + This is a base class for an interpreter which can read or write + a point cloud in a particular format. + """ + + def read( + self, path: Union[str, Path], device, path_manager: PathManager, **kwargs + ) -> Optional[Pointclouds]: + """ + Read the data from the specified file and return it as + a Pointclouds object. + + Args: + path: path to load. + device: torch.device to load data on to. + path_manager: PathManager to interpret the path. + + Returns: + None if self is not the appropriate object to interpret the given + path. + Otherwise, the read Pointclouds object. + """ + raise NotImplementedError() + + def save( + self, + data: Pointclouds, + path: Union[str, Path], + path_manager: PathManager, + binary: Optional[bool], + **kwargs, + ) -> bool: + """ + Save the given Pointclouds object to the given path. + + Args: + data: point cloud object to save + path: path to save to, which may be overwritten. + path_manager: PathManager to interpret the path. + binary: If there is a choice, whether to save in a binary format. + + Returns: + False: if self is not the appropriate object to write to the given path. + True: on success. + """ + raise NotImplementedError() diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 4b48eb6a..87909578 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -777,7 +777,7 @@ class Pointclouds(object): returned. Returns: - list[PointClouds]. + list[Pointclouds]. """ if not all(isinstance(x, int) for x in split_sizes): raise ValueError("Value of split_sizes must be a list of integers.")