Jeremy Reizenstein 65f667fd2e loading llff and blender datasets
Summary: Copy code from NeRF for loading LLFF data and blender synthetic data, and create dataset objects for them

Reviewed By: shapovalov

Differential Revision: D35581039

fbshipit-source-id: af7a6f3e9a42499700693381b5b147c991f57e5d
2022-06-16 03:09:15 -07:00

54 lines
2.0 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
class DataSourceBase(ReplaceableBase):
    """
    Abstract interface for an Implicitron data source.

    A data source bundles together the configuration of the datasets and of
    the data loaders that iterate over them. Concrete subclasses must
    override `get_datasets_and_dataloaders`.
    """

    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
        """
        Build the datasets and the data loaders for them.

        Returns:
            A pair (datasets, dataloaders).

        Raises:
            NotImplementedError: always, in this base class; subclasses are
                expected to provide the implementation.
        """
        raise NotImplementedError
class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
    """
    The data source used by Implicitron; it is the only DataSourceBase
    implementation provided.

    It composes two pluggable components:
      - a dataset map provider, which supplies the datasets, and
      - a data loader map provider, which builds data loaders from those
        datasets.

    Members:
        dataset_map_provider_class_type: names the implementation used for
            dataset_map_provider, e.g. JsonIndexDatasetMapProvider for Co3D.
            No default is declared for it here.
        data_loader_map_provider_class_type: names the implementation used
            for data_loader_map_provider. Declared default:
            "SequenceDataLoaderMapProvider".
    """

    dataset_map_provider: DatasetMapProviderBase
    dataset_map_provider_class_type: str
    data_loader_map_provider: DataLoaderMapProviderBase
    data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"

    def __post_init__(self):
        # Hand over to the config system so that it constructs the member
        # components (presumably the two providers above, selected through
        # their *_class_type fields -- see pytorch3d's run_auto_creation).
        run_auto_creation(self)

    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
        """
        Produce the datasets and the data loaders built on top of them.

        Returns:
            The DatasetMap obtained from dataset_map_provider, paired with the
            DataLoaderMap that data_loader_map_provider builds from it.
        """
        dataset_map = self.dataset_map_provider.get_dataset_map()
        loader_map = self.data_loader_map_provider.get_data_loader_map(dataset_map)
        return dataset_map, loader_map

    def get_task(self) -> Task:
        """Return the Task reported by the dataset map provider."""
        provider = self.dataset_map_provider
        return provider.get_task()