PathManagerFactory

Summary: Allow access to manifold internally by default.

Reviewed By: davnov134

Differential Revision: D36760481

fbshipit-source-id: 2a16bd40e81ef526085ac1b3f4606b63c1841428
This commit is contained in:
Jeremy Reizenstein 2022-06-10 12:22:46 -07:00 committed by Facebook GitHub Bot
parent 1fb268dea6
commit 1d43251391
7 changed files with 65 additions and 52 deletions

View File

@ -303,7 +303,9 @@ data_source_args:
image_width: 800 image_width: 800
image_height: 800 image_height: 800
remove_empty_masks: true remove_empty_masks: true
path_manager: null path_manager_factory_class_type: PathManagerFactory
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1 batch_size: 1
num_workers: 0 num_workers: 0

View File

@ -4,18 +4,13 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging
import os import os
import unittest import unittest
from pathlib import Path from pathlib import Path
import experiment import experiment
import torch import torch
from iopath.common.file_io import PathManager
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
JsonIndexDatasetMapProvider,
)
def interactive_testing_requested() -> bool: def interactive_testing_requested() -> bool:
@ -38,38 +33,9 @@ DEBUG: bool = False
# - deal with the temporary output files this test creates # - deal with the temporary output files this test creates
def get_path_manager(silence_logs: bool = False) -> PathManager:
"""
Returns a path manager which can access manifold internally.
Args:
silence_logs: Whether to reduce log output from iopath library.
"""
if silence_logs:
logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)
if os.environ.get("INSIDE_RE_WORKER", False):
raise ValueError("Cannot get to manifold from RE")
path_manager = PathManager()
if os.environ.get("FB_TEST", False):
from iopath.fb.manifold import ManifoldPathHandler
path_manager.register_handler(ManifoldPathHandler())
return path_manager
def set_path_manager(self):
self.path_manager = get_path_manager()
class TestExperiment(unittest.TestCase): class TestExperiment(unittest.TestCase):
def setUp(self): def setUp(self):
self.maxDiff = None self.maxDiff = None
JsonIndexDatasetMapProvider.__post_init__ = set_path_manager
def test_from_defaults(self): def test_from_defaults(self):
# Test making minimal changes to the dataclass defaults. # Test making minimal changes to the dataclass defaults.

View File

@ -4,11 +4,14 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Iterator, Optional from typing import Iterator, Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase from iopath.common.file_io import PathManager
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from .dataset_base import DatasetBase from .dataset_base import DatasetBase
@ -69,3 +72,39 @@ class DatasetMapProviderBase(ReplaceableBase):
def get_task(self) -> Task: def get_task(self) -> Task:
raise NotImplementedError() raise NotImplementedError()
@registry.register
class PathManagerFactory(ReplaceableBase):
"""
Base class and default implementation of a tool which dataset_map_provider implementations
may use to construct a path manager if needed.
Args:
silence_logs: Whether to reduce log output from iopath library.
"""
silence_logs: bool = True
def get(self) -> Optional[PathManager]:
"""
Makes a PathManager if needed.
For open source users, this function should always return None.
Internally, this allows manifold access.
"""
if os.environ.get("INSIDE_RE_WORKER", False):
return None
try:
from iopath.fb.manifold import ManifoldPathHandler
except ImportError:
return None
if self.silence_logs:
logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)
path_manager = PathManager()
path_manager.register_handler(ManifoldPathHandler())
return path_manager

View File

@ -11,9 +11,14 @@ from dataclasses import field
from typing import Any, Dict, List, Sequence from typing import Any, Dict, List, Sequence
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import registry from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task from .dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .json_index_dataset import JsonIndexDataset from .json_index_dataset import JsonIndexDataset
from .utils import ( from .utils import (
DATASET_TYPE_KNOWN, DATASET_TYPE_KNOWN,
@ -87,7 +92,6 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
only_test_set: Load only the test set. only_test_set: Load only the test set.
aux_dataset_kwargs: Specifies additional arguments to the aux_dataset_kwargs: Specifies additional arguments to the
JsonIndexDataset constructor call. JsonIndexDataset constructor call.
path_manager: Optional[PathManager] for interpreting paths
""" """
category: str category: str
@ -105,12 +109,18 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
assert_single_seq: bool = False assert_single_seq: bool = False
only_test_set: bool = False only_test_set: bool = False
aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config) aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config)
path_manager: Any = None path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
def __post_init__(self):
run_auto_creation(self)
def get_dataset_map(self) -> DatasetMap: def get_dataset_map(self) -> DatasetMap:
if self.only_test_set and self.test_on_train: if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train") raise ValueError("Cannot have only_test_set and test_on_train")
path_manager = self.path_manager_factory.get()
# TODO: # TODO:
# - implement loading multiple categories # - implement loading multiple categories
@ -130,7 +140,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
"load_point_clouds": self.load_point_clouds, "load_point_clouds": self.load_point_clouds,
"mask_images": self.mask_images, "mask_images": self.mask_images,
"mask_depths": self.mask_depths, "mask_depths": self.mask_depths,
"path_manager": self.path_manager, "path_manager": path_manager,
"frame_annotations_file": frame_file, "frame_annotations_file": frame_file,
"sequence_annotations_file": sequence_file, "sequence_annotations_file": sequence_file,
"subset_lists_file": subset_lists_file, "subset_lists_file": subset_lists_file,
@ -151,8 +161,8 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
self.category, self.category,
f"eval_batches_{self.task_str}.json", f"eval_batches_{self.task_str}.json",
) )
if self.path_manager is not None: if path_manager is not None:
batch_indices_path = self.path_manager.get_local_path(batch_indices_path) batch_indices_path = path_manager.get_local_path(batch_indices_path)
if not os.path.isfile(batch_indices_path): if not os.path.isfile(batch_indices_path):
# The batch indices file does not exist. # The batch indices file does not exist.
# Most probably the user has not specified the root folder. # Most probably the user has not specified the root folder.

View File

@ -11,7 +11,6 @@ from typing import Any, cast, Dict, List, Optional, Tuple
import lpips import lpips
import torch import torch
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
@ -80,7 +79,6 @@ def evaluate_dbir_for_category(
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0), bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
single_sequence_id: Optional[int] = None, single_sequence_id: Optional[int] = None,
num_workers: int = 16, num_workers: int = 16,
path_manager: Optional[PathManager] = None,
): ):
""" """
Evaluates new view synthesis metrics of a simple depth-based image rendering Evaluates new view synthesis metrics of a simple depth-based image rendering
@ -110,7 +108,6 @@ def evaluate_dbir_for_category(
"test_on_train": False, "test_on_train": False,
"load_point_clouds": True, "load_point_clouds": True,
"test_restrict_sequence_id": single_sequence_id, "test_restrict_sequence_id": single_sequence_id,
"path_manager": path_manager,
} }
data_source = ImplicitronDataSource( data_source = ImplicitronDataSource(
dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args

View File

@ -21,7 +21,9 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
image_width: 800 image_width: 800
image_height: 800 image_height: 800
remove_empty_masks: true remove_empty_masks: true
path_manager: null path_manager_factory_class_type: PathManagerFactory
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1 batch_size: 1
num_workers: 0 num_workers: 0

View File

@ -11,7 +11,7 @@ from pytorch3d.implicitron import eval_demo
from tests.common_testing import interactive_testing_requested from tests.common_testing import interactive_testing_requested
from .common_resources import CO3D_MANIFOLD_PATH, get_path_manager from .common_resources import CO3D_MANIFOLD_PATH
""" """
This test runs a single sequence eval_demo, useful for debugging datasets. This test runs a single sequence eval_demo, useful for debugging datasets.
@ -25,8 +25,5 @@ class TestEvalDemo(unittest.TestCase):
return return
os.environ["CO3D_DATASET_ROOT"] = CO3D_MANIFOLD_PATH os.environ["CO3D_DATASET_ROOT"] = CO3D_MANIFOLD_PATH
path_manager = get_path_manager(silence_logs=True)
eval_demo.evaluate_dbir_for_category( eval_demo.evaluate_dbir_for_category("donut", single_sequence_id=0)
"donut", single_sequence_id=0, path_manager=path_manager
)