data_source
Summary: Move dataset_args and dataloader_args from ExperimentConfig into a new member, data_source_args, so that it can contain replaceables. Also add an enum, Task, for the task type.

Reviewed By: shapovalov

Differential Revision: D36201719

fbshipit-source-id: 47d6967bfea3b7b146b6bbd1572e0457c9365871
parent 9ec9d057cc
commit 73dc109dba
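In short, an illustrative sketch (not part of the commit): the `Task` enum replaces the raw `"singlesequence"` / `"multisequence"` strings, and dataset/dataloader configuration now lives under a single `data_source_args` node.

```python
# Minimal sketch of the two additions in this commit; the asserts mirror the
# string values defined in the new data_source.py further below.
from pytorch3d.implicitron.dataset.data_source import Task

assert Task("singlesequence") is Task.SINGLE_SEQUENCE
assert Task("multisequence") is Task.MULTI_SEQUENCE

# Config access moves from cfg.dataset_args / cfg.dataloader_args to
# cfg.data_source_args.dataset_args / cfg.data_source_args.dataloader_args,
# consumed in run_training() via ImplicitronDataSource (see the experiment.py hunks).
```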
@@ -66,7 +66,8 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
 To run training, pass a yaml config file, followed by a list of overridden arguments.
 For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
 ```shell
-pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
+dataset_args=data_source_args.dataset_args
+pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
 ```

 Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location;
@@ -84,7 +85,8 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad

 E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
 ```shell
-pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<CO3D_DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
+dataset_args=data_source_args.dataset_args
+pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
 ```
 Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.

@@ -202,7 +204,7 @@ to replace the implementation and potentially override the parameters.
 # Code and config structure

 As per above, the config structure is parsed automatically from the module hierarchy.
-In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node.
+In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `data_source_args` node.

 Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
 ```
@@ -233,8 +235,9 @@ generic_model_args: GenericModel
         ╘== AngleWeightedReductionFeatureAggregator
         ╘== ReductionFeatureAggregator
 solver_args: init_optimizer
-dataset_args: dataset_zoo
-dataloader_args: dataloader_zoo
+data_source_args: ImplicitronDataSource
+    └-- dataset_args
+    └-- dataloader_args
 ```

 Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
@@ -5,29 +5,30 @@ exp_dir: ./data/exps/base/
 architecture: generic
 visualize_interval: 0
 visdom_port: 8097
-dataloader_args:
-  batch_size: 10
-  dataset_len: 1000
-  dataset_len_val: 1
-  num_workers: 8
-  images_per_seq_options:
-  - 2
-  - 3
-  - 4
-  - 5
-  - 6
-  - 7
-  - 8
-  - 9
-  - 10
-dataset_args:
-  dataset_root: ${oc.env:CO3D_DATASET_ROOT}
-  load_point_clouds: false
-  mask_depths: false
-  mask_images: false
-  n_frames_per_sequence: -1
-  test_on_train: true
-  test_restrict_sequence_id: 0
+data_source_args:
+  dataloader_args:
+    batch_size: 10
+    dataset_len: 1000
+    dataset_len_val: 1
+    num_workers: 8
+    images_per_seq_options:
+    - 2
+    - 3
+    - 4
+    - 5
+    - 6
+    - 7
+    - 8
+    - 9
+    - 10
+  dataset_args:
+    dataset_root: ${oc.env:CO3D_DATASET_ROOT}
+    load_point_clouds: false
+    mask_depths: false
+    mask_images: false
+    n_frames_per_sequence: -1
+    test_on_train: true
+    test_restrict_sequence_id: 0
 generic_model_args:
   loss_weights:
     loss_mask_bce: 1.0
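For reference, the command-line overrides from the README map onto this new nesting; a small OmegaConf sketch (values are placeholders, not from the commit):

```python
# Sketch only: shows how dot-list overrides resolve against the new
# data_source_args nesting; the dataset root is a placeholder path.
from omegaconf import OmegaConf

overrides = [
    "data_source_args.dataset_args.dataset_root=/path/to/co3d",
    "data_source_args.dataset_args.category=skateboard",
    "data_source_args.dataloader_args.batch_size=10",
]
cfg = OmegaConf.from_dotlist(overrides)
print(OmegaConf.to_yaml(cfg))
# data_source_args:
#   dataset_args:
#     dataset_root: /path/to/co3d
#     category: skateboard
#   dataloader_args:
#     batch_size: 10
```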
@@ -1,30 +1,31 @@
 defaults:
 - repro_base.yaml
 - _self_
-dataloader_args:
-  batch_size: 10
-  dataset_len: 1000
-  dataset_len_val: 1
-  num_workers: 8
-  images_per_seq_options:
-  - 2
-  - 3
-  - 4
-  - 5
-  - 6
-  - 7
-  - 8
-  - 9
-  - 10
-dataset_args:
-  assert_single_seq: false
-  dataset_name: co3d_multisequence
-  load_point_clouds: false
-  mask_depths: false
-  mask_images: false
-  n_frames_per_sequence: -1
-  test_on_train: true
-  test_restrict_sequence_id: 0
+data_source_args:
+  dataloader_args:
+    batch_size: 10
+    dataset_len: 1000
+    dataset_len_val: 1
+    num_workers: 8
+    images_per_seq_options:
+    - 2
+    - 3
+    - 4
+    - 5
+    - 6
+    - 7
+    - 8
+    - 9
+    - 10
+  dataset_args:
+    assert_single_seq: false
+    dataset_name: co3d_multisequence
+    load_point_clouds: false
+    mask_depths: false
+    mask_images: false
+    n_frames_per_sequence: -1
+    test_on_train: true
+    test_restrict_sequence_id: 0
 solver_args:
   max_epochs: 3000
   milestones:
@@ -1,19 +1,20 @@
 defaults:
 - repro_base
 - _self_
-dataloader_args:
-  batch_size: 1
-  dataset_len: 1000
-  dataset_len_val: 1
-  num_workers: 8
-  images_per_seq_options:
-  - 2
-dataset_args:
-  dataset_name: co3d_singlesequence
-  assert_single_seq: true
-  n_frames_per_sequence: -1
-  test_restrict_sequence_id: 0
-  test_on_train: false
+data_source_args:
+  dataloader_args:
+    batch_size: 1
+    dataset_len: 1000
+    dataset_len_val: 1
+    num_workers: 8
+    images_per_seq_options:
+    - 2
+  dataset_args:
+    dataset_name: co3d_singlesequence
+    assert_single_seq: true
+    n_frames_per_sequence: -1
+    test_restrict_sequence_id: 0
+    test_on_train: false
 generic_model_args:
   render_image_height: 800
   render_image_width: 800
@@ -1,18 +1,19 @@
 defaults:
 - repro_singleseq_base
 - _self_
-dataloader_args:
-  batch_size: 10
-  dataset_len: 1000
-  dataset_len_val: 1
-  num_workers: 8
-  images_per_seq_options:
-  - 2
-  - 3
-  - 4
-  - 5
-  - 6
-  - 7
-  - 8
-  - 9
-  - 10
+data_source_args:
+  dataloader_args:
+    batch_size: 10
+    dataset_len: 1000
+    dataset_len_val: 1
+    num_workers: 8
+    images_per_seq_options:
+    - 2
+    - 3
+    - 4
+    - 5
+    - 6
+    - 7
+    - 8
+    - 9
+    - 10
@@ -64,8 +64,9 @@ import tqdm
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
-from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
-from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
+from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
+from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
@@ -428,7 +429,7 @@ def trainvalidate(
             optimizer.step()


-def run_training(cfg: DictConfig, device: str = "cpu"):
+def run_training(cfg: DictConfig, device: str = "cpu") -> None:
    """
    Entry point to run the training and validation loops
    based on the specified config file.
@@ -452,8 +453,9 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
        warnings.warn("Cant dump config due to insufficient permissions!")

     # setup datasets
-    datasets = dataset_zoo(**cfg.dataset_args)
-    dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
+    datasource = ImplicitronDataSource(**cfg.data_source_args)
+    datasets, dataloaders = datasource.get_datasets_and_dataloaders()
+    task = datasource.get_task()

     # init the model
     model, stats, optimizer_state = init_model(cfg)
@@ -464,7 +466,7 @@

     # only run evaluation on the test dataloader
     if cfg.eval_only:
-        _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
         return

     # init the optimizer
@@ -526,7 +528,7 @@
             and cfg.test_interval > 0
             and epoch % cfg.test_interval == 0
         ):
-            run_eval(cfg, model, stats, dataloaders.test, device=device)
+            _run_eval(model, stats, dataloaders.test, task, device=device)

         assert stats.epoch == epoch, "inconsistent stats!"

@@ -546,11 +548,17 @@
            logger.info(f"LR change! {cur_lr} -> {new_lr}")

     if cfg.test_when_finished:
-        _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)


 def _eval_and_dump(
-    cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
+    cfg,
+    task: Task,
+    datasets: Datasets,
+    dataloaders: Dataloaders,
+    model,
+    stats,
+    device,
 ) -> None:
     """
     Run the evaluation loop with the test data loader and
@@ -562,16 +570,13 @@
     if dataloader is None:
         raise ValueError('Dataloaders have to contain the "test" entry for eval!')

-    eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
-    if eval_task == "singlesequence":
+    if task == Task.SINGLE_SEQUENCE:
         if datasets.train is None:
             raise ValueError("train dataset must be provided")
         all_source_cameras = _get_all_source_cameras(datasets.train)
     else:
         all_source_cameras = None
-    results = run_eval(
-        cfg, model, all_source_cameras, dataloader, eval_task, device=device
-    )
+    results = _run_eval(model, all_source_cameras, dataloader, task, device=device)

     # add the evaluation epoch to the results
     for r in results:
@@ -598,7 +603,7 @@ def _get_eval_frame_data(frame_data):
     return frame_data_for_eval


-def run_eval(cfg, model, all_source_cameras, loader, task, device):
+def _run_eval(model, all_source_cameras, loader, task: Task, device):
    """
    Run the evaluation loop on the test dataloader
    """
@@ -672,8 +677,7 @@ def _seed_all_random_engines(seed: int):
 class ExperimentConfig:
     generic_model_args: DictConfig = get_default_args_field(GenericModel)
     solver_args: DictConfig = get_default_args_field(init_optimizer)
-    dataset_args: DictConfig = get_default_args_field(dataset_zoo)
-    dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
+    data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
     architecture: str = "generic"
     detect_anomaly: bool = False
     eval_only: bool = False
@@ -23,6 +23,7 @@ import torch
 import torch.nn.functional as Fu
 from experiment import init_model
 from omegaconf import OmegaConf
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
@@ -326,12 +327,14 @@ def export_scenes(
     config.gpu_idx = gpu_idx
     config.exp_dir = exp_dir
     # important so that the CO3D dataset gets loaded in full
-    config.dataset_args.test_on_train = False
+    config.data_source_args.dataset_args.test_on_train = False
     # Set the rendering image size
     config.generic_model_args.render_image_width = render_size[0]
     config.generic_model_args.render_image_height = render_size[1]
     if restrict_sequence_name is not None:
-        config.dataset_args.restrict_sequence_name = restrict_sequence_name
+        config.data_source_args.dataset_args.restrict_sequence_name = (
+            restrict_sequence_name
+        )

     # Set up the CUDA env for the visualization
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -343,7 +346,8 @@
     model.eval()

     # Setup the dataset
-    datasets = dataset_zoo(**config.dataset_args)
+    datasource = ImplicitronDataSource(**config.data_source_args)
+    datasets = dataset_zoo(**datasource.dataset_args)
     dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
     if dataset is None:
         raise ValueError(f"{split} dataset not provided")
pytorch3d/implicitron/dataset/data_source.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Tuple
+
+from omegaconf import DictConfig
+from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase
+
+from .dataloader_zoo import dataloader_zoo, Dataloaders
+from .dataset_zoo import dataset_zoo, Datasets
+
+
+class Task(Enum):
+    SINGLE_SEQUENCE = "singlesequence"
+    MULTI_SEQUENCE = "multisequence"
+
+
+class DataSourceBase(ReplaceableBase):
+    """
+    Base class for a data source in Implicitron. It encapsulates Dataset
+    and DataLoader configuration.
+    """
+
+    def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
+        raise NotImplementedError()
+
+
+class ImplicitronDataSource(DataSourceBase):
+    """
+    Represents the data used in Implicitron. This is the only implementation
+    of DataSourceBase provided.
+    """
+
+    dataset_args: DictConfig = get_default_args_field(dataset_zoo)
+    dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
+
+    def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
+        datasets = dataset_zoo(**self.dataset_args)
+        dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
+        return datasets, dataloaders
+
+    def get_task(self) -> Task:
+        eval_task = self.dataset_args["dataset_name"].split("_")[-1]
+        return Task(eval_task)
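On the summary's point about replaceables: since `DataSourceBase` is a `ReplaceableBase`, an alternative data source could in principle be registered and selected through the config. A hypothetical sketch (the `registry` decorator is assumed from `pytorch3d.implicitron.tools.config`; `MyDataSource` is not part of this commit):

```python
from typing import Tuple

from pytorch3d.implicitron.dataset.data_source import DataSourceBase
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
from pytorch3d.implicitron.tools.config import registry


@registry.register
class MyDataSource(DataSourceBase):  # hypothetical example, not in the commit
    def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
        # A real implementation would build and return its own datasets/dataloaders.
        raise NotImplementedError()
```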
@@ -7,11 +7,12 @@

 import dataclasses
 import os
-from typing import cast, Optional, Tuple
+from typing import Any, cast, Dict, List, Optional, Tuple

 import lpips
 import torch
 from iopath.common.file_io import PathManager
+from pytorch3d.implicitron.dataset.data_source import Task
 from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
 from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
@@ -47,10 +48,12 @@ def main() -> None:
    """

     task_results = {}
-    for task in ("singlesequence", "multisequence"):
+    for task in (Task.SINGLE_SEQUENCE, Task.MULTI_SEQUENCE):
         task_results[task] = []
-        for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
-            for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
+        for category in CO3D_CATEGORIES[: (20 if task == Task.SINGLE_SEQUENCE else 10)]:
+            for single_sequence_id in (
+                (0, 1) if task == Task.SINGLE_SEQUENCE else (None,)
+            ):
                 category_result = evaluate_dbir_for_category(
                     category, task=task, single_sequence_id=single_sequence_id
                 )
@@ -74,9 +77,9 @@


 def evaluate_dbir_for_category(
-    category: str = "apple",
+    category: str,
+    task: Task,
     bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
-    task: str = "singlesequence",
     single_sequence_id: Optional[int] = None,
     num_workers: int = 16,
     path_manager: Optional[PathManager] = None,
@@ -101,14 +104,16 @@

     torch.manual_seed(42)

-    if task not in ["multisequence", "singlesequence"]:
-        raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
+    dataset_name = {
+        Task.SINGLE_SEQUENCE: "co3d_singlesequence",
+        Task.MULTI_SEQUENCE: "co3d_multisequence",
+    }[task]

     datasets = dataset_zoo(
         category=category,
         dataset_root=os.environ["CO3D_DATASET_ROOT"],
-        assert_single_seq=task == "singlesequence",
-        dataset_name=f"co3d_{task}",
+        assert_single_seq=task == Task.SINGLE_SEQUENCE,
+        dataset_name=dataset_name,
         test_on_train=False,
         load_point_clouds=True,
         test_restrict_sequence_id=single_sequence_id,
@@ -122,7 +127,7 @@
     if test_dataset is None or test_dataloader is None:
         raise ValueError("must have a test dataset.")

-    if task == "singlesequence":
+    if task == Task.SINGLE_SEQUENCE:
         # all_source_cameras are needed for evaluation of the
         # target camera difficulty
         # pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
@@ -173,7 +178,9 @@
     return category_result["results"]


-def _print_aggregate_results(task, task_results) -> None:
+def _print_aggregate_results(
+    task: Task, task_results: Dict[Task, List[List[Dict[str, Any]]]]
+) -> None:
    """
    Prints the aggregate metrics for a given task.
    """
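The new annotation spells out the shape that `main()` builds: one entry per `Task`, one inner list per evaluated category/sequence run, each holding per-batch metric dicts. A sketch with dummy values:

```python
from typing import Any, Dict, List

from pytorch3d.implicitron.dataset.data_source import Task

# Dummy values; the metric names are placeholders, not taken from the commit.
task_results: Dict[Task, List[List[Dict[str, Any]]]] = {
    Task.SINGLE_SEQUENCE: [[{"psnr": 24.1}], [{"psnr": 22.7}]],
    Task.MULTI_SEQUENCE: [],
}
```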
@@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
+from pytorch3d.implicitron.dataset.data_source import Task
 from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
 from pytorch3d.implicitron.models.base_model import ImplicitronRender
@@ -317,7 +318,7 @@
         if visualize:
             visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
             if break_after_visualising:
-                import pdb
+                import pdb  # noqa: B602

                 pdb.set_trace()

@@ -411,16 +412,16 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
     return ious.topk(k=min(topk, len(ious) - 1)).values.mean()


-def get_camera_difficulty_bin_edges(task: str):
+def _get_camera_difficulty_bin_edges(task: Task):
    """
    Get the edges of camera difficulty bins.
    """
     _eps = 1e-5
-    if task == "multisequence":
+    if task == Task.MULTI_SEQUENCE:
         # TODO: extract those to constants
         diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
         diff_bin_edges[0] = 0.0 - _eps
-    elif task == "singlesequence":
+    elif task == Task.SINGLE_SEQUENCE:
         diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
     else:
         raise ValueError(f"No such eval task {task}.")
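For intuition, the multi-sequence branch above yields four bin edges at roughly 0, 0.67, 0.83 and 1.0; a quick check (illustrative only, repeating the same arithmetic outside the function):

```python
import torch

_eps = 1e-5
edges = torch.linspace(0.5, 1.0 + _eps, 4)
edges[0] = 0.0 - _eps
print(edges)  # approximately: [-1.0e-05, 0.6667, 0.8333, 1.0]
```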
@@ -430,7 +431,7 @@

 def summarize_nvs_eval_results(
     per_batch_eval_results: List[Dict[str, Any]],
-    task: str = "singlesequence",
+    task: Task,
 ):
    """
    Compile the per-batch evaluation results `per_batch_eval_results` into
@@ -439,7 +440,6 @@
    Args:
        per_batch_eval_results: Metrics of each per-batch evaluation.
        task: The type of the new-view synthesis task.
-            Either 'singlesequence' or 'multisequence'.

    Returns:
        nvs_results_flat: A flattened dict of all aggregate metrics.
@@ -447,10 +447,10 @@
    """
     n_batches = len(per_batch_eval_results)
     eval_sets: List[Optional[str]] = []
-    if task == "singlesequence":
+    if task == Task.SINGLE_SEQUENCE:
         eval_sets = [None]
         # assert n_batches==100
-    elif task == "multisequence":
+    elif task == Task.MULTI_SEQUENCE:
         eval_sets = ["train", "test"]
         # assert n_batches==1000
     else:
@@ -466,17 +466,17 @@
     # init the result database dict
     results = []

-    diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task)
+    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
     n_diff_edges = diff_bin_edges.numel()

     # add per set averages
     for SET in eval_sets:
         if SET is None:
-            # task=='singlesequence'
+            assert task == Task.SINGLE_SEQUENCE
             ok_set = torch.ones(n_batches, dtype=torch.bool)
             set_name = "test"
         else:
-            # task=='multisequence'
+            assert task == Task.MULTI_SEQUENCE
             ok_set = is_train == int(SET == "train")
             set_name = SET

@@ -501,7 +501,7 @@
            }
        )

-    if task == "multisequence":
+    if task == Task.MULTI_SEQUENCE:
        # split based on n_src_views
        n_src_views = batch_sizes - 1
        for n_src in EVAL_N_SRC_VIEWS: