diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md index d53aa3b9..80253a20 100644 --- a/projects/implicitron_trainer/README.md +++ b/projects/implicitron_trainer/README.md @@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace To run training, pass a yaml config file, followed by a list of overridden arguments. For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run: ```shell -dataset_args=data_source_args.dataset_args +dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root= $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir= ``` @@ -85,7 +85,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad E.g. for executing the evaluation on the NeRF skateboard sequence, you can run: ```shell -dataset_args=data_source_args.dataset_args +dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root= $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir= eval_only=True ``` Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`. @@ -236,7 +236,7 @@ generic_model_args: GenericModel ╘== ReductionFeatureAggregator solver_args: init_optimizer data_source_args: ImplicitronDataSource -└-- dataset_args +└-- dataset_map_provider_*_args └-- dataloader_args ``` diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index be4ab289..21fc39c5 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -6,6 +6,7 @@ architecture: generic visualize_interval: 0 visdom_port: 8097 data_source_args: + dataset_provider_class_type: JsonIndexDatasetMapProvider dataloader_args: batch_size: 10 dataset_len: 1000 @@ -21,7 +22,7 @@ data_source_args: - 8 - 9 - 10 - dataset_args: + dataset_map_provider_JsonIndexDatasetMapProvider_args: dataset_root: ${oc.env:CO3D_DATASET_ROOT} load_point_clouds: false mask_depths: false diff --git a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml index 128f29be..ae3eae32 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml @@ -17,9 +17,9 @@ data_source_args: - 8 - 9 - 10 - dataset_args: + dataset_map_provider_JsonIndexDatasetMapProvider_args: assert_single_seq: false - dataset_name: co3d_multisequence + task_str: multisequence load_point_clouds: false mask_depths: false mask_images: false diff --git a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml index 4e082395..1419b7e0 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml @@ -9,8 +9,8 @@ data_source_args: num_workers: 8 images_per_seq_options: - 2 - dataset_args: - dataset_name: co3d_singlesequence + dataset_map_provider_JsonIndexDatasetMapProvider_args: + dataset_name: singlesequence assert_single_seq: true n_frames_per_sequence: -1 test_restrict_sequence_id: 0 diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 738d848a..db6e0591 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -67,7 +67,7 @@ from pytorch3d.implicitron.dataset import utils as ds_utils from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders from pytorch3d.implicitron.dataset.dataset_base import FrameData -from pytorch3d.implicitron.dataset.dataset_zoo import Datasets +from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel @@ -552,7 +552,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None: def _eval_and_dump( cfg, task: Task, - datasets: Datasets, + datasets: DatasetMap, dataloaders: Dataloaders, model, stats, diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py index 66946ef5..773a321d 100644 --- a/projects/implicitron_trainer/visualize_reconstruction.py +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -24,8 +24,7 @@ import torch.nn.functional as Fu from experiment import init_model from omegaconf import OmegaConf from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource -from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase -from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo +from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.utils import is_train_frame from pytorch3d.implicitron.models.base_model import EvaluationMode @@ -296,7 +295,7 @@ def export_scenes( output_directory: Optional[str] = None, render_size: Tuple[int, int] = (512, 512), video_size: Optional[Tuple[int, int]] = None, - split: str = "train", # train | test + split: str = "train", # train | val | test n_source_views: int = 9, n_eval_cameras: int = 40, visdom_server="http://127.0.0.1", @@ -324,14 +323,15 @@ def export_scenes( config.gpu_idx = gpu_idx config.exp_dir = exp_dir # important so that the CO3D dataset gets loaded in full - config.data_source_args.dataset_args.test_on_train = False + dataset_args = ( + config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args + ) + dataset_args.test_on_train = False # Set the rendering image size config.generic_model_args.render_image_width = render_size[0] config.generic_model_args.render_image_height = render_size[1] if restrict_sequence_name is not None: - config.data_source_args.dataset_args.restrict_sequence_name = ( - restrict_sequence_name - ) + dataset_args.restrict_sequence_name = restrict_sequence_name # Set up the CUDA env for the visualization os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -344,8 +344,8 @@ def export_scenes( # Setup the dataset datasource = ImplicitronDataSource(**config.data_source_args) - datasets = dataset_zoo(**datasource.dataset_args) - dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None) + dataset_map = datasource.dataset_map_provider.get_dataset_map() + dataset = dataset_map[split] if dataset is None: raise ValueError(f"{split} dataset not provided") diff --git a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py index 89acb282..8723e089 100644 --- a/pytorch3d/implicitron/dataset/data_source.py +++ b/pytorch3d/implicitron/dataset/data_source.py @@ -4,19 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum from typing import Tuple from omegaconf import DictConfig -from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase +from pytorch3d.implicitron.tools.config import ( + get_default_args_field, + ReplaceableBase, + run_auto_creation, +) +from . import json_index_dataset_map_provider # noqa from .dataloader_zoo import dataloader_zoo, Dataloaders -from .dataset_zoo import dataset_zoo, Datasets - - -class Task(Enum): - SINGLE_SEQUENCE = "singlesequence" - MULTI_SEQUENCE = "multisequence" +from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task class DataSourceBase(ReplaceableBase): @@ -25,24 +24,31 @@ class DataSourceBase(ReplaceableBase): and DataLoader configuration. """ - def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]: + def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]: raise NotImplementedError() -class ImplicitronDataSource(DataSourceBase): +class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] """ Represents the data used in Implicitron. This is the only implementation of DataSourceBase provided. + + Members: + dataset_map_provider_class_type: identifies type for dataset_map_provider. + e.g. JsonIndexDatasetMapProvider for Co3D. """ - dataset_args: DictConfig = get_default_args_field(dataset_zoo) + dataset_map_provider: DatasetMapProviderBase + dataset_map_provider_class_type: str dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) - def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]: - datasets = dataset_zoo(**self.dataset_args) + def __post_init__(self): + run_auto_creation(self) + + def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]: + datasets = self.dataset_map_provider.get_dataset_map() dataloaders = dataloader_zoo(datasets, **self.dataloader_args) return datasets, dataloaders def get_task(self) -> Task: - eval_task = self.dataset_args["dataset_name"].split("_")[-1] - return Task(eval_task) + return self.dataset_map_provider.get_task() diff --git a/pytorch3d/implicitron/dataset/dataloader_zoo.py b/pytorch3d/implicitron/dataset/dataloader_zoo.py index bf7376af..58cb576e 100644 --- a/pytorch3d/implicitron/dataset/dataloader_zoo.py +++ b/pytorch3d/implicitron/dataset/dataloader_zoo.py @@ -11,7 +11,7 @@ import torch from pytorch3d.implicitron.tools.config import enable_get_default_args from .dataset_base import FrameData, ImplicitronDatasetBase -from .dataset_zoo import Datasets +from .dataset_map_provider import DatasetMap from .scene_batch_sampler import SceneBatchSampler @@ -33,7 +33,7 @@ class Dataloaders: def dataloader_zoo( - datasets: Datasets, + datasets: DatasetMap, batch_size: int = 1, num_workers: int = 0, dataset_len: int = 1000, @@ -49,7 +49,6 @@ def dataloader_zoo( Args: datasets: A dictionary containing the `"dataset_subset_name": torch_dataset_object` key, value pairs. - dataset_name: The name of the returned dataset. batch_size: The size of the batch of the dataloader. num_workers: Number data-loading threads. dataset_len: The number of batches in a training epoch. diff --git a/pytorch3d/implicitron/dataset/dataset_map_provider.py b/pytorch3d/implicitron/dataset/dataset_map_provider.py new file mode 100644 index 00000000..3177a0d7 --- /dev/null +++ b/pytorch3d/implicitron/dataset/dataset_map_provider.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from typing import Iterator, Optional + +from pytorch3d.implicitron.tools.config import ReplaceableBase + +from .dataset_base import ImplicitronDatasetBase + + +@dataclass +class DatasetMap: + """ + A collection of datasets for implicitron. + + Members: + + train: a dataset for training + val: a dataset for validating during training + test: a dataset for final evaluation + """ + + train: Optional[ImplicitronDatasetBase] + val: Optional[ImplicitronDatasetBase] + test: Optional[ImplicitronDatasetBase] + + def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]: + """ + Get one of the datasets by key (name of data split) + """ + if split not in ["train", "val", "test"]: + raise ValueError(f"{split} was not a valid split name (train/val/test)") + return getattr(self, split) + + def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]: + """ + Iterator over all datasets. + """ + if self.train is not None: + yield self.train + if self.val is not None: + yield self.val + if self.test is not None: + yield self.test + + +class Task(Enum): + SINGLE_SEQUENCE = "singlesequence" + MULTI_SEQUENCE = "multisequence" + + +class DatasetMapProviderBase(ReplaceableBase): + """ + Base class for a provider of training / validation and testing + dataset objects. + """ + + def get_dataset_map(self) -> DatasetMap: + """ + Returns: + An object containing the torch.Dataset objects in train/val/test fields. + """ + raise NotImplementedError() + + def get_task(self) -> Task: + raise NotImplementedError() diff --git a/pytorch3d/implicitron/dataset/dataset_zoo.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py similarity index 58% rename from pytorch3d/implicitron/dataset/dataset_zoo.py rename to pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py index 023b538b..1932697f 100644 --- a/pytorch3d/implicitron/dataset/dataset_zoo.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py @@ -7,13 +7,13 @@ import json import os -from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Sequence +from dataclasses import field +from typing import Any, Dict, List, Sequence -from iopath.common.file_io import PathManager -from pytorch3d.implicitron.tools.config import enable_get_default_args +from omegaconf import DictConfig +from pytorch3d.implicitron.tools.config import registry -from .dataset_base import ImplicitronDatasetBase +from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task from .implicitron_dataset import ImplicitronDataset from .utils import ( DATASET_TYPE_KNOWN, @@ -34,6 +34,11 @@ DATASET_CONFIGS: Dict[str, Dict[str, Any]] = { } } + +def _make_default_config() -> DictConfig: + return DictConfig(DATASET_CONFIGS["default"]) + + # fmt: off CO3D_CATEGORIES: List[str] = list(reversed([ "baseballbat", "banana", "bicycle", "microwave", "tv", @@ -53,59 +58,16 @@ CO3D_CATEGORIES: List[str] = list(reversed([ _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "") -@dataclass -class Datasets: +@registry.register +class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] """ - A provider of datasets for implicitron. - - Members: - - train: a dataset for training - val: a dataset for validating during training - test: a dataset for final evaluation - """ - - train: Optional[ImplicitronDatasetBase] - val: Optional[ImplicitronDatasetBase] - test: Optional[ImplicitronDatasetBase] - - def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]: - """ - Iterator over all datasets. - """ - if self.train is not None: - yield self.train - if self.val is not None: - yield self.val - if self.test is not None: - yield self.test - - -def dataset_zoo( - dataset_name: str = "co3d_singlesequence", - dataset_root: str = _CO3D_DATASET_ROOT, - category: str = "DEFAULT", - limit_to: int = -1, - limit_sequences_to: int = -1, - n_frames_per_sequence: int = -1, - test_on_train: bool = False, - load_point_clouds: bool = False, - mask_images: bool = False, - mask_depths: bool = False, - restrict_sequence_name: Sequence[str] = (), - test_restrict_sequence_id: int = -1, - assert_single_seq: bool = False, - only_test_set: bool = False, - aux_dataset_kwargs: dict = DATASET_CONFIGS["default"], - path_manager: Optional[PathManager] = None, -) -> Datasets: - """ - Generates the training / validation and testing dataset objects. + Generates the training / validation and testing dataset objects for + a dataset laid out on disk like Co3D, with annotations in json files. Args: - dataset_name: The name of the returned dataset. - dataset_root: The root folder of the dataset. category: The object category of the dataset. + task_str: "multisequence" or "singlesequence". + dataset_root: The root folder of the dataset. limit_to: Limit the dataset to the first #limit_to frames. limit_sequences_to: Limit the dataset to the first #limit_sequences_to sequences. @@ -119,58 +81,78 @@ def dataset_zoo( restrict_sequence_name: Restrict the dataset sequences to the ones present in the given list of names. test_restrict_sequence_id: The ID of the loaded sequence. - Active for dataset_name='co3d_singlesequence'. + Active for task_str='singlesequence'. assert_single_seq: Assert that only frames from a single sequence are present in all generated datasets. only_test_set: Load only the test set. aux_dataset_kwargs: Specifies additional arguments to the ImplicitronDataset constructor call. - - Returns: - datasets: A dictionary containing the - `"dataset_subset_name": torch_dataset_object` key, value pairs. + path_manager: Optional[PathManager] for interpreting paths """ - if only_test_set and test_on_train: - raise ValueError("Cannot have only_test_set and test_on_train") - # TODO: - # - implement loading multiple categories + category: str + task_str: str = "singlesequence" + dataset_root: str = _CO3D_DATASET_ROOT + limit_to: int = -1 + limit_sequences_to: int = -1 + n_frames_per_sequence: int = -1 + test_on_train: bool = False + load_point_clouds: bool = False + mask_images: bool = False + mask_depths: bool = False + restrict_sequence_name: Sequence[str] = () + test_restrict_sequence_id: int = -1 + assert_single_seq: bool = False + only_test_set: bool = False + aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config) + path_manager: Any = None - if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]: - frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") - sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") - subset_lists_file = os.path.join(dataset_root, category, "set_lists.json") + def get_dataset_map(self) -> DatasetMap: + if self.only_test_set and self.test_on_train: + raise ValueError("Cannot have only_test_set and test_on_train") + + # TODO: + # - implement loading multiple categories + + frame_file = os.path.join( + self.dataset_root, self.category, "frame_annotations.jgz" + ) + sequence_file = os.path.join( + self.dataset_root, self.category, "sequence_annotations.jgz" + ) + subset_lists_file = os.path.join( + self.dataset_root, self.category, "set_lists.json" + ) common_kwargs = { - "dataset_root": dataset_root, - "limit_to": limit_to, - "limit_sequences_to": limit_sequences_to, - "load_point_clouds": load_point_clouds, - "mask_images": mask_images, - "mask_depths": mask_depths, - "path_manager": path_manager, + "dataset_root": self.dataset_root, + "limit_to": self.limit_to, + "limit_sequences_to": self.limit_sequences_to, + "load_point_clouds": self.load_point_clouds, + "mask_images": self.mask_images, + "mask_depths": self.mask_depths, + "path_manager": self.path_manager, "frame_annotations_file": frame_file, "sequence_annotations_file": sequence_file, "subset_lists_file": subset_lists_file, - **aux_dataset_kwargs, + **self.aux_dataset_kwargs, } # This maps the common names of the dataset subsets ("train"/"val"/"test") # to the names of the subsets in the CO3D dataset. set_names_mapping = _get_co3d_set_names_mapping( - dataset_name, - test_on_train, - only_test_set, + self.get_task(), + self.test_on_train, + self.only_test_set, ) # load the evaluation batches - task = dataset_name.split("_")[-1] batch_indices_path = os.path.join( - dataset_root, - category, - f"eval_batches_{task}.json", + self.dataset_root, + self.category, + f"eval_batches_{self.task_str}.json", ) - if path_manager is not None: - batch_indices_path = path_manager.get_local_path(batch_indices_path) + if self.path_manager is not None: + batch_indices_path = self.path_manager.get_local_path(batch_indices_path) if not os.path.isfile(batch_indices_path): # The batch indices file does not exist. # Most probably the user has not specified the root folder. @@ -181,25 +163,31 @@ def dataset_zoo( with open(batch_indices_path, "r") as f: eval_batch_index = json.load(f) + restrict_sequence_name = self.restrict_sequence_name - if task == "singlesequence": - assert ( - test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0 - ), ( - "Please specify an integer id 'test_restrict_sequence_id'" - + " of the sequence considered for 'singlesequence'" - + " training and evaluation." - ) - assert len(restrict_sequence_name) == 0, ( - "For the 'singlesequence' task, the restrict_sequence_name has" - " to be unset while test_restrict_sequence_id has to be set to an" - " integer defining the order of the evaluation sequence." - ) + if self.get_task() == Task.SINGLE_SEQUENCE: + if ( + self.test_restrict_sequence_id is None + or self.test_restrict_sequence_id < 0 + ): + raise ValueError( + "Please specify an integer id 'test_restrict_sequence_id'" + + " of the sequence considered for 'singlesequence'" + + " training and evaluation." + ) + if len(self.restrict_sequence_name) > 0: + raise ValueError( + "For the 'singlesequence' task, the restrict_sequence_name has" + " to be unset while test_restrict_sequence_id has to be set to an" + " integer defining the order of the evaluation sequence." + ) # a sort-stable set() equivalent: eval_batches_sequence_names = list( {b[0][0]: None for b in eval_batch_index}.keys() ) - eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id] + eval_sequence_name = eval_batches_sequence_names[ + self.test_restrict_sequence_id + ] eval_batch_index = [ b for b in eval_batch_index if b[0][0] == eval_sequence_name ] @@ -207,14 +195,14 @@ def dataset_zoo( restrict_sequence_name = [eval_sequence_name] train_dataset = None - if not only_test_set: + if not self.only_test_set: train_dataset = ImplicitronDataset( - n_frames_per_sequence=n_frames_per_sequence, + n_frames_per_sequence=self.n_frames_per_sequence, subsets=set_names_mapping["train"], pick_sequence=restrict_sequence_name, **common_kwargs, ) - if test_on_train: + if self.test_on_train: assert train_dataset is not None val_dataset = test_dataset = train_dataset else: @@ -237,29 +225,26 @@ def dataset_zoo( test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index( eval_batch_index ) - datasets = Datasets(train=train_dataset, val=val_dataset, test=test_dataset) + datasets = DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset) - else: - raise ValueError(f"Unsupported dataset: {dataset_name}") + if self.assert_single_seq: + # check there's only one sequence in all datasets + sequence_names = { + sequence_name + for dset in datasets.iter_datasets() + for sequence_name in dset.sequence_names() + } + if len(sequence_names) > 1: + raise ValueError("Multiple sequences loaded but expected one") - if assert_single_seq: - # check there's only one sequence in all datasets - sequence_names = { - sequence_name - for dset in datasets.iter_datasets() - for sequence_name in dset.sequence_names() - } - if len(sequence_names) > 1: - raise ValueError("Multiple sequences loaded but expected one") + return datasets - return datasets - - -enable_get_default_args(dataset_zoo) + def get_task(self) -> Task: + return Task(self.task_str) def _get_co3d_set_names_mapping( - dataset_name: str, + task: Task, test_on_train: bool, only_test: bool, ) -> Dict[str, List[str]]: @@ -273,7 +258,7 @@ def _get_co3d_set_names_mapping( - val (if not test_on_train) - test (if not test_on_train) """ - single_seq = dataset_name == "co3d_singlesequence" + single_seq = task == Task.SINGLE_SEQUENCE if only_test: set_names_mapping = {} diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index a460ea0f..f8de2560 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -12,11 +12,12 @@ from typing import Any, cast, Dict, List, Optional, Tuple import lpips import torch from iopath.common.file_io import PathManager -from pytorch3d.implicitron.dataset.data_source import Task -from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo +from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase -from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset +from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import ( + CO3D_CATEGORIES, +) from pytorch3d.implicitron.dataset.utils import is_known_frame from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( aggregate_nvs_results, @@ -101,23 +102,21 @@ def evaluate_dbir_for_category( torch.manual_seed(42) - dataset_name = { - Task.SINGLE_SEQUENCE: "co3d_singlesequence", - Task.MULTI_SEQUENCE: "co3d_multisequence", - }[task] - - datasets = dataset_zoo( - category=category, - dataset_root=os.environ["CO3D_DATASET_ROOT"], - assert_single_seq=task == Task.SINGLE_SEQUENCE, - dataset_name=dataset_name, - test_on_train=False, - load_point_clouds=True, - test_restrict_sequence_id=single_sequence_id, - path_manager=path_manager, + dataset_map_provider_args = { + "category": category, + "dataset_root": os.environ["CO3D_DATASET_ROOT"], + "assert_single_seq": task == Task.SINGLE_SEQUENCE, + "task_str": task.value, + "test_on_train": False, + "load_point_clouds": True, + "test_restrict_sequence_id": single_sequence_id, + "path_manager": path_manager, + } + data_source = ImplicitronDataSource( + dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args ) - dataloaders = dataloader_zoo(datasets) + datasets, dataloaders = data_source.get_datasets_and_dataloaders() test_dataset = datasets.test test_dataloader = dataloaders.test diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml new file mode 100644 index 00000000..bff7ae08 --- /dev/null +++ b/tests/implicitron/data/data_source.yaml @@ -0,0 +1,33 @@ +dataset_map_provider_class_type: ??? +dataloader_args: + batch_size: 1 + num_workers: 0 + dataset_len: 1000 + dataset_len_val: 1 + images_per_seq_options: + - 2 + sample_consecutive_frames: false + consecutive_frames_max_gap: 0 + consecutive_frames_max_gap_seconds: 0.1 +dataset_map_provider_JsonIndexDatasetMapProvider_args: + category: ??? + task_str: singlesequence + dataset_root: '' + limit_to: -1 + limit_sequences_to: -1 + n_frames_per_sequence: -1 + test_on_train: false + load_point_clouds: false + mask_images: false + mask_depths: false + restrict_sequence_name: [] + test_restrict_sequence_id: -1 + assert_single_seq: false + only_test_set: false + aux_dataset_kwargs: + box_crop: true + box_crop_context: 0.3 + image_width: 800 + image_height: 800 + remove_empty_masks: true + path_manager: null diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index eda265d0..ea6cbb2f 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -118,6 +118,6 @@ implicit_function_IdrFeatureField_args: bias: 1.0 skip_in: [] weight_norm: true - n_harmonic_functions_xyz: 0 + n_harmonic_functions_xyz: 1729 pooled_feature_dim: 0 encoding_dim: 0 diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py index 00d5c6fd..3820b93a 100644 --- a/tests/implicitron/test_config_use.py +++ b/tests/implicitron/test_config_use.py @@ -70,6 +70,9 @@ class TestGenericModel(unittest.TestCase): "AngleWeightedIdentityFeatureAggregator" ) args.implicit_function_class_type = "IdrFeatureField" + idr_args = args.implicit_function_IdrFeatureField_args + idr_args.n_harmonic_functions_xyz = 1729 + args.renderer_class_type = "LSTMRenderer" gm = GenericModel(**args) self.assertIsInstance(gm.renderer, LSTMRenderer) @@ -78,6 +81,7 @@ class TestGenericModel(unittest.TestCase): AngleWeightedIdentityFeatureAggregator, ) self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) + self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729) self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor) self.assertFalse(hasattr(gm, "implicit_function")) diff --git a/tests/implicitron/test_data_source.py b/tests/implicitron/test_data_source.py new file mode 100644 index 00000000..845aa858 --- /dev/null +++ b/tests/implicitron/test_data_source.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +from omegaconf import OmegaConf +from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource +from pytorch3d.implicitron.tools.config import get_default_args + +if os.environ.get("FB_TEST", False): + from common_testing import get_tests_dir +else: + from tests.common_testing import get_tests_dir + +DATA_DIR = get_tests_dir() / "implicitron/data" +DEBUG: bool = False + + +class TestDataSource(unittest.TestCase): + def setUp(self): + self.maxDiff = None + + def test_one(self): + cfg = get_default_args(ImplicitronDataSource) + yaml = OmegaConf.to_yaml(cfg, sort_keys=False) + if DEBUG: + (DATA_DIR / "data_source.yaml").write_text(yaml) + self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())