data_source

Summary:
Move dataset_args and dataloader_args from ExperimentConfig into a new member called datasource so that it can contain replaceables.

Also add enum Task for task type.

Reviewed By: shapovalov

Differential Revision: D36201719

fbshipit-source-id: 47d6967bfea3b7b146b6bbd1572e0457c9365871
This commit is contained in:
Jeremy Reizenstein 2022-05-20 07:50:30 -07:00 committed by Facebook GitHub Bot
parent 9ec9d057cc
commit 73dc109dba
10 changed files with 194 additions and 124 deletions

View File

@ -66,7 +66,8 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
dataset_args=data_source_args.dataset_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
```
Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location;
@ -84,7 +85,8 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<CO3D_DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
dataset_args=data_source_args.dataset_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
```
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
@ -202,7 +204,7 @@ to replace the implementation and potentially override the parameters.
# Code and config structure
As per above, the config structure is parsed automatically from the module hierarchy.
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node.
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `data_source_args` node.
Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
```
@ -233,8 +235,9 @@ generic_model_args: GenericModel
╘== AngleWeightedReductionFeatureAggregator
╘== ReductionFeatureAggregator
solver_args: init_optimizer
dataset_args: dataset_zoo
dataloader_args: dataloader_zoo
data_source_args: ImplicitronDataSource
└-- dataset_args
└-- dataloader_args
```
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.

View File

@ -5,29 +5,30 @@ exp_dir: ./data/exps/base/
architecture: generic
visualize_interval: 0
visdom_port: 8097
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
data_source_args:
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
generic_model_args:
loss_weights:
loss_mask_bce: 1.0

View File

@ -1,30 +1,31 @@
defaults:
- repro_base.yaml
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
assert_single_seq: false
dataset_name: co3d_multisequence
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
data_source_args:
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
assert_single_seq: false
dataset_name: co3d_multisequence
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
solver_args:
max_epochs: 3000
milestones:

View File

@ -1,19 +1,20 @@
defaults:
- repro_base
- _self_
dataloader_args:
batch_size: 1
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
dataset_args:
dataset_name: co3d_singlesequence
assert_single_seq: true
n_frames_per_sequence: -1
test_restrict_sequence_id: 0
test_on_train: false
data_source_args:
dataloader_args:
batch_size: 1
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
dataset_args:
dataset_name: co3d_singlesequence
assert_single_seq: true
n_frames_per_sequence: -1
test_restrict_sequence_id: 0
test_on_train: false
generic_model_args:
render_image_height: 800
render_image_width: 800

View File

@ -1,18 +1,19 @@
defaults:
- repro_singleseq_base
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
data_source_args:
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10

View File

@ -64,8 +64,9 @@ import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
@ -428,7 +429,7 @@ def trainvalidate(
optimizer.step()
def run_training(cfg: DictConfig, device: str = "cpu"):
def run_training(cfg: DictConfig, device: str = "cpu") -> None:
"""
Entry point to run the training and validation loops
based on the specified config file.
@ -452,8 +453,9 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
warnings.warn("Cant dump config due to insufficient permissions!")
# setup datasets
datasets = dataset_zoo(**cfg.dataset_args)
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
datasource = ImplicitronDataSource(**cfg.data_source_args)
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
task = datasource.get_task()
# init the model
model, stats, optimizer_state = init_model(cfg)
@ -464,7 +466,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
# only run evaluation on the test dataloader
if cfg.eval_only:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
return
# init the optimizer
@ -526,7 +528,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
and cfg.test_interval > 0
and epoch % cfg.test_interval == 0
):
run_eval(cfg, model, stats, dataloaders.test, device=device)
_run_eval(model, stats, dataloaders.test, task, device=device)
assert stats.epoch == epoch, "inconsistent stats!"
@ -546,11 +548,17 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if cfg.test_when_finished:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
def _eval_and_dump(
cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
cfg,
task: Task,
datasets: Datasets,
dataloaders: Dataloaders,
model,
stats,
device,
) -> None:
"""
Run the evaluation loop with the test data loader and
@ -562,16 +570,13 @@ def _eval_and_dump(
if dataloader is None:
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
if eval_task == "singlesequence":
if task == Task.SINGLE_SEQUENCE:
if datasets.train is None:
raise ValueError("train dataset must be provided")
all_source_cameras = _get_all_source_cameras(datasets.train)
else:
all_source_cameras = None
results = run_eval(
cfg, model, all_source_cameras, dataloader, eval_task, device=device
)
results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
# add the evaluation epoch to the results
for r in results:
@ -598,7 +603,7 @@ def _get_eval_frame_data(frame_data):
return frame_data_for_eval
def run_eval(cfg, model, all_source_cameras, loader, task, device):
def _run_eval(model, all_source_cameras, loader, task: Task, device):
"""
Run the evaluation loop on the test dataloader
"""
@ -672,8 +677,7 @@ def _seed_all_random_engines(seed: int):
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False

View File

@ -23,6 +23,7 @@ import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
@ -326,12 +327,14 @@ def export_scenes(
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.dataset_args.test_on_train = False
config.data_source_args.dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.dataset_args.restrict_sequence_name = restrict_sequence_name
config.data_source_args.dataset_args.restrict_sequence_name = (
restrict_sequence_name
)
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@ -343,7 +346,8 @@ def export_scenes(
model.eval()
# Setup the dataset
datasets = dataset_zoo(**config.dataset_args)
datasource = ImplicitronDataSource(**config.data_source_args)
datasets = dataset_zoo(**datasource.dataset_args)
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
if dataset is None:
raise ValueError(f"{split} dataset not provided")

View File

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase
from .dataloader_zoo import dataloader_zoo, Dataloaders
from .dataset_zoo import dataset_zoo, Datasets
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
class DataSourceBase(ReplaceableBase):
"""
Base class for a data source in Implicitron. It encapsulates Dataset
and DataLoader configuration.
"""
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
raise NotImplementedError()
class ImplicitronDataSource(DataSourceBase):
"""
Represents the data used in Implicitron. This is the only implementation
of DataSourceBase provided.
"""
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
datasets = dataset_zoo(**self.dataset_args)
dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
return datasets, dataloaders
def get_task(self) -> Task:
eval_task = self.dataset_args["dataset_name"].split("_")[-1]
return Task(eval_task)

View File

@ -7,11 +7,12 @@
import dataclasses
import os
from typing import cast, Optional, Tuple
from typing import Any, cast, Dict, List, Optional, Tuple
import lpips
import torch
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
@ -47,10 +48,12 @@ def main() -> None:
"""
task_results = {}
for task in ("singlesequence", "multisequence"):
for task in (Task.SINGLE_SEQUENCE, Task.MULTI_SEQUENCE):
task_results[task] = []
for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
for category in CO3D_CATEGORIES[: (20 if task == Task.SINGLE_SEQUENCE else 10)]:
for single_sequence_id in (
(0, 1) if task == Task.SINGLE_SEQUENCE else (None,)
):
category_result = evaluate_dbir_for_category(
category, task=task, single_sequence_id=single_sequence_id
)
@ -74,9 +77,9 @@ def main() -> None:
def evaluate_dbir_for_category(
category: str = "apple",
category: str,
task: Task,
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
task: str = "singlesequence",
single_sequence_id: Optional[int] = None,
num_workers: int = 16,
path_manager: Optional[PathManager] = None,
@ -101,14 +104,16 @@ def evaluate_dbir_for_category(
torch.manual_seed(42)
if task not in ["multisequence", "singlesequence"]:
raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
dataset_name = {
Task.SINGLE_SEQUENCE: "co3d_singlesequence",
Task.MULTI_SEQUENCE: "co3d_multisequence",
}[task]
datasets = dataset_zoo(
category=category,
dataset_root=os.environ["CO3D_DATASET_ROOT"],
assert_single_seq=task == "singlesequence",
dataset_name=f"co3d_{task}",
assert_single_seq=task == Task.SINGLE_SEQUENCE,
dataset_name=dataset_name,
test_on_train=False,
load_point_clouds=True,
test_restrict_sequence_id=single_sequence_id,
@ -122,7 +127,7 @@ def evaluate_dbir_for_category(
if test_dataset is None or test_dataloader is None:
raise ValueError("must have a test dataset.")
if task == "singlesequence":
if task == Task.SINGLE_SEQUENCE:
# all_source_cameras are needed for evaluation of the
# target camera difficulty
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
@ -173,7 +178,9 @@ def evaluate_dbir_for_category(
return category_result["results"]
def _print_aggregate_results(task, task_results) -> None:
def _print_aggregate_results(
task: Task, task_results: Dict[Task, List[List[Dict[str, Any]]]]
) -> None:
"""
Prints the aggregate metrics for a given task.
"""

View File

@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender
@ -317,7 +318,7 @@ def eval_batch(
if visualize:
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
if break_after_visualising:
import pdb
import pdb # noqa: B602
pdb.set_trace()
@ -411,16 +412,16 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
def get_camera_difficulty_bin_edges(task: str):
def _get_camera_difficulty_bin_edges(task: Task):
"""
Get the edges of camera difficulty bins.
"""
_eps = 1e-5
if task == "multisequence":
if task == Task.MULTI_SEQUENCE:
# TODO: extract those to constants
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
diff_bin_edges[0] = 0.0 - _eps
elif task == "singlesequence":
elif task == Task.SINGLE_SEQUENCE:
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
else:
raise ValueError(f"No such eval task {task}.")
@ -430,7 +431,7 @@ def get_camera_difficulty_bin_edges(task: str):
def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]],
task: str = "singlesequence",
task: Task,
):
"""
Compile the per-batch evaluation results `per_batch_eval_results` into
@ -439,7 +440,6 @@ def summarize_nvs_eval_results(
Args:
per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task.
Either 'singlesequence' or 'multisequence'.
Returns:
nvs_results_flat: A flattened dict of all aggregate metrics.
@ -447,10 +447,10 @@ def summarize_nvs_eval_results(
"""
n_batches = len(per_batch_eval_results)
eval_sets: List[Optional[str]] = []
if task == "singlesequence":
if task == Task.SINGLE_SEQUENCE:
eval_sets = [None]
# assert n_batches==100
elif task == "multisequence":
elif task == Task.MULTI_SEQUENCE:
eval_sets = ["train", "test"]
# assert n_batches==1000
else:
@ -466,17 +466,17 @@ def summarize_nvs_eval_results(
# init the result database dict
results = []
diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task)
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
n_diff_edges = diff_bin_edges.numel()
# add per set averages
for SET in eval_sets:
if SET is None:
# task=='singlesequence'
assert task == Task.SINGLE_SEQUENCE
ok_set = torch.ones(n_batches, dtype=torch.bool)
set_name = "test"
else:
# task=='multisequence'
assert task == Task.MULTI_SEQUENCE
ok_set = is_train == int(SET == "train")
set_name = SET
@ -501,7 +501,7 @@ def summarize_nvs_eval_results(
}
)
if task == "multisequence":
if task == Task.MULTI_SEQUENCE:
# split based on n_src_views
n_src_views = batch_sizes - 1
for n_src in EVAL_N_SRC_VIEWS: