From 1d43251391c4c05ce88670d90dd4ef25bc1c2cb7 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 10 Jun 2022 12:22:46 -0700 Subject: [PATCH] PathManagerFactory Summary: Allow access to manifold internally by default. Reviewed By: davnov134 Differential Revision: D36760481 fbshipit-source-id: 2a16bd40e81ef526085ac1b3f4606b63c1841428 --- .../implicitron_trainer/tests/experiment.yaml | 4 +- .../tests/test_experiment.py | 34 --------------- .../dataset/dataset_map_provider.py | 41 ++++++++++++++++++- .../json_index_dataset_map_provider.py | 24 +++++++---- pytorch3d/implicitron/eval_demo.py | 3 -- tests/implicitron/data/data_source.yaml | 4 +- tests/implicitron/test_eval_demo.py | 7 +--- 7 files changed, 65 insertions(+), 52 deletions(-) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index 640fa31b..95e3837a 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -303,7 +303,9 @@ data_source_args: image_width: 800 image_height: 800 remove_empty_masks: true - path_manager: null + path_manager_factory_class_type: PathManagerFactory + path_manager_factory_PathManagerFactory_args: + silence_logs: true data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index a7052aef..73e3b995 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -4,18 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging import os import unittest from pathlib import Path import experiment import torch -from iopath.common.file_io import PathManager from omegaconf import OmegaConf -from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import ( - JsonIndexDatasetMapProvider, -) def interactive_testing_requested() -> bool: @@ -38,38 +33,9 @@ DEBUG: bool = False # - deal with the temporary output files this test creates -def get_path_manager(silence_logs: bool = False) -> PathManager: - """ - Returns a path manager which can access manifold internally. - - Args: - silence_logs: Whether to reduce log output from iopath library. - """ - if silence_logs: - logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL) - logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL) - - if os.environ.get("INSIDE_RE_WORKER", False): - raise ValueError("Cannot get to manifold from RE") - - path_manager = PathManager() - - if os.environ.get("FB_TEST", False): - from iopath.fb.manifold import ManifoldPathHandler - - path_manager.register_handler(ManifoldPathHandler()) - - return path_manager - - -def set_path_manager(self): - self.path_manager = get_path_manager() - - class TestExperiment(unittest.TestCase): def setUp(self): self.maxDiff = None - JsonIndexDatasetMapProvider.__post_init__ = set_path_manager def test_from_defaults(self): # Test making minimal changes to the dataclass defaults. diff --git a/pytorch3d/implicitron/dataset/dataset_map_provider.py b/pytorch3d/implicitron/dataset/dataset_map_provider.py index 810ce234..72b60535 100644 --- a/pytorch3d/implicitron/dataset/dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/dataset_map_provider.py @@ -4,11 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +import os from dataclasses import dataclass from enum import Enum from typing import Iterator, Optional -from pytorch3d.implicitron.tools.config import ReplaceableBase +from iopath.common.file_io import PathManager +from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from .dataset_base import DatasetBase @@ -69,3 +72,39 @@ class DatasetMapProviderBase(ReplaceableBase): def get_task(self) -> Task: raise NotImplementedError() + + +@registry.register +class PathManagerFactory(ReplaceableBase): + """ + Base class and default implementation of a tool which dataset_map_provider implementations + may use to construct a path manager if needed. + + Args: + silence_logs: Whether to reduce log output from iopath library. + """ + + silence_logs: bool = True + + def get(self) -> Optional[PathManager]: + """ + Makes a PathManager if needed. + For open source users, this function should always return None. + Internally, this allows manifold access. + """ + if os.environ.get("INSIDE_RE_WORKER", False): + return None + + try: + from iopath.fb.manifold import ManifoldPathHandler + except ImportError: + return None + + if self.silence_logs: + logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL) + logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL) + + path_manager = PathManager() + path_manager.register_handler(ManifoldPathHandler()) + + return path_manager diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py index 461cb8b5..a6d4ba40 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py @@ -11,9 +11,14 @@ from dataclasses import field from typing import Any, Dict, List, Sequence from omegaconf import DictConfig -from pytorch3d.implicitron.tools.config import registry +from pytorch3d.implicitron.tools.config import registry, run_auto_creation -from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task +from .dataset_map_provider import ( + DatasetMap, + DatasetMapProviderBase, + PathManagerFactory, + Task, +) from .json_index_dataset import JsonIndexDataset from .utils import ( DATASET_TYPE_KNOWN, @@ -87,7 +92,6 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] only_test_set: Load only the test set. aux_dataset_kwargs: Specifies additional arguments to the JsonIndexDataset constructor call. - path_manager: Optional[PathManager] for interpreting paths """ category: str @@ -105,12 +109,18 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] assert_single_seq: bool = False only_test_set: bool = False aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config) - path_manager: Any = None + path_manager_factory: PathManagerFactory + path_manager_factory_class_type: str = "PathManagerFactory" + + def __post_init__(self): + run_auto_creation(self) def get_dataset_map(self) -> DatasetMap: if self.only_test_set and self.test_on_train: raise ValueError("Cannot have only_test_set and test_on_train") + path_manager = self.path_manager_factory.get() + # TODO: # - implement loading multiple categories @@ -130,7 +140,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] "load_point_clouds": self.load_point_clouds, "mask_images": self.mask_images, "mask_depths": self.mask_depths, - "path_manager": self.path_manager, + "path_manager": path_manager, "frame_annotations_file": frame_file, "sequence_annotations_file": sequence_file, "subset_lists_file": subset_lists_file, @@ -151,8 +161,8 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] self.category, f"eval_batches_{self.task_str}.json", ) - if self.path_manager is not None: - batch_indices_path = self.path_manager.get_local_path(batch_indices_path) + if path_manager is not None: + batch_indices_path = path_manager.get_local_path(batch_indices_path) if not os.path.isfile(batch_indices_path): # The batch indices file does not exist. # Most probably the user has not specified the root folder. diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index 54267c42..1cb53368 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -11,7 +11,6 @@ from typing import Any, cast, Dict, List, Optional, Tuple import lpips import torch -from iopath.common.file_io import PathManager from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset @@ -80,7 +79,6 @@ def evaluate_dbir_for_category( bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0), single_sequence_id: Optional[int] = None, num_workers: int = 16, - path_manager: Optional[PathManager] = None, ): """ Evaluates new view synthesis metrics of a simple depth-based image rendering @@ -110,7 +108,6 @@ def evaluate_dbir_for_category( "test_on_train": False, "load_point_clouds": True, "test_restrict_sequence_id": single_sequence_id, - "path_manager": path_manager, } data_source = ImplicitronDataSource( dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml index e71dac1f..d7e3c5db 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -21,7 +21,9 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args: image_width: 800 image_height: 800 remove_empty_masks: true - path_manager: null + path_manager_factory_class_type: PathManagerFactory + path_manager_factory_PathManagerFactory_args: + silence_logs: true data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 diff --git a/tests/implicitron/test_eval_demo.py b/tests/implicitron/test_eval_demo.py index 390398a2..37df395d 100644 --- a/tests/implicitron/test_eval_demo.py +++ b/tests/implicitron/test_eval_demo.py @@ -11,7 +11,7 @@ from pytorch3d.implicitron import eval_demo from tests.common_testing import interactive_testing_requested -from .common_resources import CO3D_MANIFOLD_PATH, get_path_manager +from .common_resources import CO3D_MANIFOLD_PATH """ This test runs a single sequence eval_demo, useful for debugging datasets. @@ -25,8 +25,5 @@ class TestEvalDemo(unittest.TestCase): return os.environ["CO3D_DATASET_ROOT"] = CO3D_MANIFOLD_PATH - path_manager = get_path_manager(silence_logs=True) - eval_demo.evaluate_dbir_for_category( - "donut", single_sequence_id=0, path_manager=path_manager - ) + eval_demo.evaluate_dbir_for_category("donut", single_sequence_id=0)