mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-27 08:46:00 +08:00
Summary: Remove the dataset's need to provide the task type. Reviewed By: davnov134, kjchalup Differential Revision: D38314000 fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
78 lines
2.9 KiB
Python
78 lines
2.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
from pytorch3d.implicitron.tools.config import (
|
|
registry,
|
|
ReplaceableBase,
|
|
run_auto_creation,
|
|
)
|
|
from pytorch3d.renderer.cameras import CamerasBase
|
|
|
|
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
|
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
|
|
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
|
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
|
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
|
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa
|
|
|
|
|
|
class DataSourceBase(ReplaceableBase):
    """
    Base class for a data source in Implicitron. It encapsulates Dataset
    and DataLoader configuration.

    Subclasses are expected to override both get_datasets_and_dataloaders
    and the all_train_cameras property; the versions defined here only
    raise NotImplementedError.
    """

    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
        """
        Return a (DatasetMap, DataLoaderMap) pair for this data source.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def all_train_cameras(self) -> Optional[CamerasBase]:
        """
        If the data is all for a single scene, the known training cameras
        for that scene (a CamerasBase object, not a Python list). They are
        used for evaluating the viewpoint difficulty of the unseen cameras.
        NOTE(review): presumably None when no such cameras are available;
        confirm against concrete implementations.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
@registry.register
class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
    """
    Represents the data used in Implicitron. This is the only implementation
    of DataSourceBase provided.

    It combines a dataset map provider, which produces the datasets, with a
    data loader map provider, which produces data loaders from those datasets.

    Members:
        dataset_map_provider_class_type: identifies the type used for
            dataset_map_provider, e.g. JsonIndexDatasetMapProvider for Co3D.
        data_loader_map_provider_class_type: identifies the type used for
            data_loader_map_provider; defaults to
            "SequenceDataLoaderMapProvider".
    """

    # Configurable members. Each *_map_provider is paired with a
    # *_class_type string naming which implementation to use.
    dataset_map_provider: DatasetMapProviderBase
    dataset_map_provider_class_type: str
    data_loader_map_provider: DataLoaderMapProviderBase
    data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"

    def __post_init__(self):
        # NOTE(review): run_auto_creation presumably instantiates the
        # *_map_provider members from their *_class_type fields; confirm in
        # pytorch3d.implicitron.tools.config.
        run_auto_creation(self)
        # Cache for all_train_cameras. The result is stored wrapped in a
        # 1-tuple so that a provider answer of None can itself be cached;
        # a bare None here means "not fetched yet".
        self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None

    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
        """
        Build the dataset map, then the data loader map derived from it.

        This method does not cache its result; every call asks both
        providers again.

        Returns:
            A (datasets, dataloaders) pair.
        """
        dataset_map = self.dataset_map_provider.get_dataset_map()
        loader_map = self.data_loader_map_provider.get_data_loader_map(dataset_map)
        return dataset_map, loader_map

    @property
    def all_train_cameras(self) -> Optional[CamerasBase]:
        """
        Training cameras reported by
        dataset_map_provider.get_all_train_cameras().

        The provider is queried only on first access. The answer, including
        None, is memoized and returned on every later access.
        """
        cached = self._all_train_cameras_cache  # pyre-ignore[16]
        if cached is None:
            # First access: ask the provider and remember its answer.
            cached = (self.dataset_map_provider.get_all_train_cameras(),)
            self._all_train_cameras_cache = cached
        return cached[0]
|