diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md index 7b4eb72b..d53aa3b9 100644 --- a/projects/implicitron_trainer/README.md +++ b/projects/implicitron_trainer/README.md @@ -66,7 +66,8 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace To run training, pass a yaml config file, followed by a list of overridden arguments. For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run: ```shell -pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root= dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir= +dataset_args=data_source_args.dataset_args +pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root= $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir= ``` Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location; @@ -84,7 +85,8 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad E.g. for executing the evaluation on the NeRF skateboard sequence, you can run: ```shell -pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root= dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir= eval_only=True +dataset_args=data_source_args.dataset_args +pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root= $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir= eval_only=True ``` Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`. @@ -202,7 +204,7 @@ to replace the implementation and potentially override the parameters. # Code and config structure As per above, the config structure is parsed automatically from the module hierarchy. -In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node. +In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `data_source_args` node. Here is the class structure (single-line edges show aggregation, while double lines show available implementations): ``` @@ -233,8 +235,9 @@ generic_model_args: GenericModel ╘== AngleWeightedReductionFeatureAggregator ╘== ReductionFeatureAggregator solver_args: init_optimizer -dataset_args: dataset_zoo -dataloader_args: dataloader_zoo +data_source_args: ImplicitronDataSource +└-- dataset_args +└-- dataloader_args ``` Please look at the annotations of the respective classes or functions for the lists of hyperparameters. diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index 595e69e5..be4ab289 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -5,29 +5,30 @@ exp_dir: ./data/exps/base/ architecture: generic visualize_interval: 0 visdom_port: 8097 -dataloader_args: - batch_size: 10 - dataset_len: 1000 - dataset_len_val: 1 - num_workers: 8 - images_per_seq_options: - - 2 - - 3 - - 4 - - 5 - - 6 - - 7 - - 8 - - 9 - - 10 -dataset_args: - dataset_root: ${oc.env:CO3D_DATASET_ROOT} - load_point_clouds: false - mask_depths: false - mask_images: false - n_frames_per_sequence: -1 - test_on_train: true - test_restrict_sequence_id: 0 +data_source_args: + dataloader_args: + batch_size: 10 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 + dataset_args: + dataset_root: ${oc.env:CO3D_DATASET_ROOT} + load_point_clouds: false + mask_depths: false + mask_images: false + n_frames_per_sequence: -1 + test_on_train: true + test_restrict_sequence_id: 0 generic_model_args: loss_weights: loss_mask_bce: 1.0 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml index 12abe1ae..128f29be 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml @@ -1,30 +1,31 @@ defaults: - repro_base.yaml - _self_ -dataloader_args: - batch_size: 10 - dataset_len: 1000 - dataset_len_val: 1 - num_workers: 8 - images_per_seq_options: - - 2 - - 3 - - 4 - - 5 - - 6 - - 7 - - 8 - - 9 - - 10 -dataset_args: - assert_single_seq: false - dataset_name: co3d_multisequence - load_point_clouds: false - mask_depths: false - mask_images: false - n_frames_per_sequence: -1 - test_on_train: true - test_restrict_sequence_id: 0 +data_source_args: + dataloader_args: + batch_size: 10 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 + dataset_args: + assert_single_seq: false + dataset_name: co3d_multisequence + load_point_clouds: false + mask_depths: false + mask_images: false + n_frames_per_sequence: -1 + test_on_train: true + test_restrict_sequence_id: 0 solver_args: max_epochs: 3000 milestones: diff --git a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml index bbec0f4c..4e082395 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml @@ -1,19 +1,20 @@ defaults: - repro_base - _self_ -dataloader_args: - batch_size: 1 - dataset_len: 1000 - dataset_len_val: 1 - num_workers: 8 - images_per_seq_options: - - 2 -dataset_args: - dataset_name: co3d_singlesequence - assert_single_seq: true - n_frames_per_sequence: -1 - test_restrict_sequence_id: 0 - test_on_train: false +data_source_args: + dataloader_args: + batch_size: 1 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 + dataset_args: + dataset_name: co3d_singlesequence + assert_single_seq: true + n_frames_per_sequence: -1 + test_restrict_sequence_id: 0 + test_on_train: false generic_model_args: render_image_height: 800 render_image_width: 800 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml index f8ae682a..57de6cf4 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml @@ -1,18 +1,19 @@ defaults: - repro_singleseq_base - _self_ -dataloader_args: - batch_size: 10 - dataset_len: 1000 - dataset_len_val: 1 - num_workers: 8 - images_per_seq_options: - - 2 - - 3 - - 4 - - 5 - - 6 - - 7 - - 8 - - 9 - - 10 +data_source_args: + dataloader_args: + batch_size: 10 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 482f02aa..d042f97e 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -64,8 +64,9 @@ import tqdm from omegaconf import DictConfig, OmegaConf from packaging import version from pytorch3d.implicitron.dataset import utils as ds_utils -from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders -from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets +from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task +from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders +from pytorch3d.implicitron.dataset.dataset_zoo import Datasets from pytorch3d.implicitron.dataset.implicitron_dataset import ( FrameData, ImplicitronDataset, @@ -428,7 +429,7 @@ def trainvalidate( optimizer.step() -def run_training(cfg: DictConfig, device: str = "cpu"): +def run_training(cfg: DictConfig, device: str = "cpu") -> None: """ Entry point to run the training and validation loops based on the specified config file. @@ -452,8 +453,9 @@ def run_training(cfg: DictConfig, device: str = "cpu"): warnings.warn("Cant dump config due to insufficient permissions!") # setup datasets - datasets = dataset_zoo(**cfg.dataset_args) - dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args) + datasource = ImplicitronDataSource(**cfg.data_source_args) + datasets, dataloaders = datasource.get_datasets_and_dataloaders() + task = datasource.get_task() # init the model model, stats, optimizer_state = init_model(cfg) @@ -464,7 +466,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"): # only run evaluation on the test dataloader if cfg.eval_only: - _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device) + _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device) return # init the optimizer @@ -526,7 +528,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"): and cfg.test_interval > 0 and epoch % cfg.test_interval == 0 ): - run_eval(cfg, model, stats, dataloaders.test, device=device) + _run_eval(model, stats, dataloaders.test, task, device=device) assert stats.epoch == epoch, "inconsistent stats!" @@ -546,11 +548,17 @@ def run_training(cfg: DictConfig, device: str = "cpu"): logger.info(f"LR change! {cur_lr} -> {new_lr}") if cfg.test_when_finished: - _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device) + _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device) def _eval_and_dump( - cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device + cfg, + task: Task, + datasets: Datasets, + dataloaders: Dataloaders, + model, + stats, + device, ) -> None: """ Run the evaluation loop with the test data loader and @@ -562,16 +570,13 @@ def _eval_and_dump( if dataloader is None: raise ValueError('Dataloaders have to contain the "test" entry for eval!') - eval_task = cfg.dataset_args["dataset_name"].split("_")[-1] - if eval_task == "singlesequence": + if task == Task.SINGLE_SEQUENCE: if datasets.train is None: raise ValueError("train dataset must be provided") all_source_cameras = _get_all_source_cameras(datasets.train) else: all_source_cameras = None - results = run_eval( - cfg, model, all_source_cameras, dataloader, eval_task, device=device - ) + results = _run_eval(model, all_source_cameras, dataloader, task, device=device) # add the evaluation epoch to the results for r in results: @@ -598,7 +603,7 @@ def _get_eval_frame_data(frame_data): return frame_data_for_eval -def run_eval(cfg, model, all_source_cameras, loader, task, device): +def _run_eval(model, all_source_cameras, loader, task: Task, device): """ Run the evaluation loop on the test dataloader """ @@ -672,8 +677,7 @@ def _seed_all_random_engines(seed: int): class ExperimentConfig: generic_model_args: DictConfig = get_default_args_field(GenericModel) solver_args: DictConfig = get_default_args_field(init_optimizer) - dataset_args: DictConfig = get_default_args_field(dataset_zoo) - dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) + data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource) architecture: str = "generic" detect_anomaly: bool = False eval_only: bool = False diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py index 506a8054..64f95224 100644 --- a/projects/implicitron_trainer/visualize_reconstruction.py +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -23,6 +23,7 @@ import torch import torch.nn.functional as Fu from experiment import init_model from omegaconf import OmegaConf +from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo from pytorch3d.implicitron.dataset.implicitron_dataset import ( FrameData, @@ -326,12 +327,14 @@ def export_scenes( config.gpu_idx = gpu_idx config.exp_dir = exp_dir # important so that the CO3D dataset gets loaded in full - config.dataset_args.test_on_train = False + config.data_source_args.dataset_args.test_on_train = False # Set the rendering image size config.generic_model_args.render_image_width = render_size[0] config.generic_model_args.render_image_height = render_size[1] if restrict_sequence_name is not None: - config.dataset_args.restrict_sequence_name = restrict_sequence_name + config.data_source_args.dataset_args.restrict_sequence_name = ( + restrict_sequence_name + ) # Set up the CUDA env for the visualization os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -343,7 +346,8 @@ def export_scenes( model.eval() # Setup the dataset - datasets = dataset_zoo(**config.dataset_args) + datasource = ImplicitronDataSource(**config.data_source_args) + datasets = dataset_zoo(**datasource.dataset_args) dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None) if dataset is None: raise ValueError(f"{split} dataset not provided") diff --git a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py new file mode 100644 index 00000000..89acb282 --- /dev/null +++ b/pytorch3d/implicitron/dataset/data_source.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Tuple + +from omegaconf import DictConfig +from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase + +from .dataloader_zoo import dataloader_zoo, Dataloaders +from .dataset_zoo import dataset_zoo, Datasets + + +class Task(Enum): + SINGLE_SEQUENCE = "singlesequence" + MULTI_SEQUENCE = "multisequence" + + +class DataSourceBase(ReplaceableBase): + """ + Base class for a data source in Implicitron. It encapsulates Dataset + and DataLoader configuration. + """ + + def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]: + raise NotImplementedError() + + +class ImplicitronDataSource(DataSourceBase): + """ + Represents the data used in Implicitron. This is the only implementation + of DataSourceBase provided. + """ + + dataset_args: DictConfig = get_default_args_field(dataset_zoo) + dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) + + def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]: + datasets = dataset_zoo(**self.dataset_args) + dataloaders = dataloader_zoo(datasets, **self.dataloader_args) + return datasets, dataloaders + + def get_task(self) -> Task: + eval_task = self.dataset_args["dataset_name"].split("_")[-1] + return Task(eval_task) diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index 60bffb70..9b2e223c 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -7,11 +7,12 @@ import dataclasses import os -from typing import cast, Optional, Tuple +from typing import Any, cast, Dict, List, Optional, Tuple import lpips import torch from iopath.common.file_io import PathManager +from pytorch3d.implicitron.dataset.data_source import Task from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo from pytorch3d.implicitron.dataset.implicitron_dataset import ( @@ -47,10 +48,12 @@ def main() -> None: """ task_results = {} - for task in ("singlesequence", "multisequence"): + for task in (Task.SINGLE_SEQUENCE, Task.MULTI_SEQUENCE): task_results[task] = [] - for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]: - for single_sequence_id in (0, 1) if task == "singlesequence" else (None,): + for category in CO3D_CATEGORIES[: (20 if task == Task.SINGLE_SEQUENCE else 10)]: + for single_sequence_id in ( + (0, 1) if task == Task.SINGLE_SEQUENCE else (None,) + ): category_result = evaluate_dbir_for_category( category, task=task, single_sequence_id=single_sequence_id ) @@ -74,9 +77,9 @@ def main() -> None: def evaluate_dbir_for_category( - category: str = "apple", + category: str, + task: Task, bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0), - task: str = "singlesequence", single_sequence_id: Optional[int] = None, num_workers: int = 16, path_manager: Optional[PathManager] = None, @@ -101,14 +104,16 @@ def evaluate_dbir_for_category( torch.manual_seed(42) - if task not in ["multisequence", "singlesequence"]: - raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'") + dataset_name = { + Task.SINGLE_SEQUENCE: "co3d_singlesequence", + Task.MULTI_SEQUENCE: "co3d_multisequence", + }[task] datasets = dataset_zoo( category=category, dataset_root=os.environ["CO3D_DATASET_ROOT"], - assert_single_seq=task == "singlesequence", - dataset_name=f"co3d_{task}", + assert_single_seq=task == Task.SINGLE_SEQUENCE, + dataset_name=dataset_name, test_on_train=False, load_point_clouds=True, test_restrict_sequence_id=single_sequence_id, @@ -122,7 +127,7 @@ def evaluate_dbir_for_category( if test_dataset is None or test_dataloader is None: raise ValueError("must have a test dataset.") - if task == "singlesequence": + if task == Task.SINGLE_SEQUENCE: # all_source_cameras are needed for evaluation of the # target camera difficulty # pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`. @@ -173,7 +178,9 @@ def evaluate_dbir_for_category( return category_result["results"] -def _print_aggregate_results(task, task_results) -> None: +def _print_aggregate_results( + task: Task, task_results: Dict[Task, List[List[Dict[str, Any]]]] +) -> None: """ Prints the aggregate metrics for a given task. """ diff --git a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py index 37467acf..c1a16068 100644 --- a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py +++ b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py @@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union import numpy as np import torch import torch.nn.functional as F +from pytorch3d.implicitron.dataset.data_source import Task from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame from pytorch3d.implicitron.models.base_model import ImplicitronRender @@ -317,7 +318,7 @@ def eval_batch( if visualize: visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now) if break_after_visualising: - import pdb + import pdb # noqa: B602 pdb.set_trace() @@ -411,16 +412,16 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso return ious.topk(k=min(topk, len(ious) - 1)).values.mean() -def get_camera_difficulty_bin_edges(task: str): +def _get_camera_difficulty_bin_edges(task: Task): """ Get the edges of camera difficulty bins. """ _eps = 1e-5 - if task == "multisequence": + if task == Task.MULTI_SEQUENCE: # TODO: extract those to constants diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4) diff_bin_edges[0] = 0.0 - _eps - elif task == "singlesequence": + elif task == Task.SINGLE_SEQUENCE: diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float() else: raise ValueError(f"No such eval task {task}.") @@ -430,7 +431,7 @@ def get_camera_difficulty_bin_edges(task: str): def summarize_nvs_eval_results( per_batch_eval_results: List[Dict[str, Any]], - task: str = "singlesequence", + task: Task, ): """ Compile the per-batch evaluation results `per_batch_eval_results` into @@ -439,7 +440,6 @@ def summarize_nvs_eval_results( Args: per_batch_eval_results: Metrics of each per-batch evaluation. task: The type of the new-view synthesis task. - Either 'singlesequence' or 'multisequence'. Returns: nvs_results_flat: A flattened dict of all aggregate metrics. @@ -447,10 +447,10 @@ def summarize_nvs_eval_results( """ n_batches = len(per_batch_eval_results) eval_sets: List[Optional[str]] = [] - if task == "singlesequence": + if task == Task.SINGLE_SEQUENCE: eval_sets = [None] # assert n_batches==100 - elif task == "multisequence": + elif task == Task.MULTI_SEQUENCE: eval_sets = ["train", "test"] # assert n_batches==1000 else: @@ -466,17 +466,17 @@ def summarize_nvs_eval_results( # init the result database dict results = [] - diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task) + diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task) n_diff_edges = diff_bin_edges.numel() # add per set averages for SET in eval_sets: if SET is None: - # task=='singlesequence' + assert task == Task.SINGLE_SEQUENCE ok_set = torch.ones(n_batches, dtype=torch.bool) set_name = "test" else: - # task=='multisequence' + assert task == Task.MULTI_SEQUENCE ok_set = is_train == int(SET == "train") set_name = SET @@ -501,7 +501,7 @@ def summarize_nvs_eval_results( } ) - if task == "multisequence": + if task == Task.MULTI_SEQUENCE: # split based on n_src_views n_src_views = batch_sizes - 1 for n_src in EVAL_N_SRC_VIEWS: