data_source

Summary:
Move dataset_args and dataloader_args from ExperimentConfig into a new member called datasource so that it can contain replaceables.

Also add enum Task for task type.

Reviewed By: shapovalov

Differential Revision: D36201719

fbshipit-source-id: 47d6967bfea3b7b146b6bbd1572e0457c9365871
This commit is contained in:
Jeremy Reizenstein 2022-05-20 07:50:30 -07:00 committed by Facebook GitHub Bot
parent 9ec9d057cc
commit 73dc109dba
10 changed files with 194 additions and 124 deletions

View File

@ -66,7 +66,8 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
To run training, pass a yaml config file, followed by a list of overridden arguments. To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run: For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
```shell ```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR> dataset_args=data_source_args.dataset_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
``` ```
Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location; Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location;
@ -84,7 +85,8 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run: E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
```shell ```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<CO3D_DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True dataset_args=data_source_args.dataset_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
``` ```
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`. Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
@ -202,7 +204,7 @@ to replace the implementation and potentially override the parameters.
# Code and config structure # Code and config structure
As per above, the config structure is parsed automatically from the module hierarchy. As per above, the config structure is parsed automatically from the module hierarchy.
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node. In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `data_source_args` node.
Here is the class structure (single-line edges show aggregation, while double lines show available implementations): Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
``` ```
@ -233,8 +235,9 @@ generic_model_args: GenericModel
╘== AngleWeightedReductionFeatureAggregator ╘== AngleWeightedReductionFeatureAggregator
╘== ReductionFeatureAggregator ╘== ReductionFeatureAggregator
solver_args: init_optimizer solver_args: init_optimizer
dataset_args: dataset_zoo data_source_args: ImplicitronDataSource
dataloader_args: dataloader_zoo └-- dataset_args
└-- dataloader_args
``` ```
Please look at the annotations of the respective classes or functions for the lists of hyperparameters. Please look at the annotations of the respective classes or functions for the lists of hyperparameters.

View File

@ -5,7 +5,8 @@ exp_dir: ./data/exps/base/
architecture: generic architecture: generic
visualize_interval: 0 visualize_interval: 0
visdom_port: 8097 visdom_port: 8097
dataloader_args: data_source_args:
dataloader_args:
batch_size: 10 batch_size: 10
dataset_len: 1000 dataset_len: 1000
dataset_len_val: 1 dataset_len_val: 1
@ -20,7 +21,7 @@ dataloader_args:
- 8 - 8
- 9 - 9
- 10 - 10
dataset_args: dataset_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT} dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false load_point_clouds: false
mask_depths: false mask_depths: false

View File

@ -1,7 +1,8 @@
defaults: defaults:
- repro_base.yaml - repro_base.yaml
- _self_ - _self_
dataloader_args: data_source_args:
dataloader_args:
batch_size: 10 batch_size: 10
dataset_len: 1000 dataset_len: 1000
dataset_len_val: 1 dataset_len_val: 1
@ -16,7 +17,7 @@ dataloader_args:
- 8 - 8
- 9 - 9
- 10 - 10
dataset_args: dataset_args:
assert_single_seq: false assert_single_seq: false
dataset_name: co3d_multisequence dataset_name: co3d_multisequence
load_point_clouds: false load_point_clouds: false

View File

@ -1,14 +1,15 @@
defaults: defaults:
- repro_base - repro_base
- _self_ - _self_
dataloader_args: data_source_args:
dataloader_args:
batch_size: 1 batch_size: 1
dataset_len: 1000 dataset_len: 1000
dataset_len_val: 1 dataset_len_val: 1
num_workers: 8 num_workers: 8
images_per_seq_options: images_per_seq_options:
- 2 - 2
dataset_args: dataset_args:
dataset_name: co3d_singlesequence dataset_name: co3d_singlesequence
assert_single_seq: true assert_single_seq: true
n_frames_per_sequence: -1 n_frames_per_sequence: -1

View File

@ -1,7 +1,8 @@
defaults: defaults:
- repro_singleseq_base - repro_singleseq_base
- _self_ - _self_
dataloader_args: data_source_args:
dataloader_args:
batch_size: 10 batch_size: 10
dataset_len: 1000 dataset_len: 1000
dataset_len_val: 1 dataset_len_val: 1

View File

@ -64,8 +64,9 @@ import tqdm
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from packaging import version from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
from pytorch3d.implicitron.dataset.implicitron_dataset import ( from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData, FrameData,
ImplicitronDataset, ImplicitronDataset,
@ -428,7 +429,7 @@ def trainvalidate(
optimizer.step() optimizer.step()
def run_training(cfg: DictConfig, device: str = "cpu"): def run_training(cfg: DictConfig, device: str = "cpu") -> None:
""" """
Entry point to run the training and validation loops Entry point to run the training and validation loops
based on the specified config file. based on the specified config file.
@ -452,8 +453,9 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
warnings.warn("Cant dump config due to insufficient permissions!") warnings.warn("Cant dump config due to insufficient permissions!")
# setup datasets # setup datasets
datasets = dataset_zoo(**cfg.dataset_args) datasource = ImplicitronDataSource(**cfg.data_source_args)
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args) datasets, dataloaders = datasource.get_datasets_and_dataloaders()
task = datasource.get_task()
# init the model # init the model
model, stats, optimizer_state = init_model(cfg) model, stats, optimizer_state = init_model(cfg)
@ -464,7 +466,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
# only run evaluation on the test dataloader # only run evaluation on the test dataloader
if cfg.eval_only: if cfg.eval_only:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device) _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
return return
# init the optimizer # init the optimizer
@ -526,7 +528,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
and cfg.test_interval > 0 and cfg.test_interval > 0
and epoch % cfg.test_interval == 0 and epoch % cfg.test_interval == 0
): ):
run_eval(cfg, model, stats, dataloaders.test, device=device) _run_eval(model, stats, dataloaders.test, task, device=device)
assert stats.epoch == epoch, "inconsistent stats!" assert stats.epoch == epoch, "inconsistent stats!"
@ -546,11 +548,17 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
logger.info(f"LR change! {cur_lr} -> {new_lr}") logger.info(f"LR change! {cur_lr} -> {new_lr}")
if cfg.test_when_finished: if cfg.test_when_finished:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device) _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
def _eval_and_dump( def _eval_and_dump(
cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device cfg,
task: Task,
datasets: Datasets,
dataloaders: Dataloaders,
model,
stats,
device,
) -> None: ) -> None:
""" """
Run the evaluation loop with the test data loader and Run the evaluation loop with the test data loader and
@ -562,16 +570,13 @@ def _eval_and_dump(
if dataloader is None: if dataloader is None:
raise ValueError('Dataloaders have to contain the "test" entry for eval!') raise ValueError('Dataloaders have to contain the "test" entry for eval!')
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1] if task == Task.SINGLE_SEQUENCE:
if eval_task == "singlesequence":
if datasets.train is None: if datasets.train is None:
raise ValueError("train dataset must be provided") raise ValueError("train dataset must be provided")
all_source_cameras = _get_all_source_cameras(datasets.train) all_source_cameras = _get_all_source_cameras(datasets.train)
else: else:
all_source_cameras = None all_source_cameras = None
results = run_eval( results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
cfg, model, all_source_cameras, dataloader, eval_task, device=device
)
# add the evaluation epoch to the results # add the evaluation epoch to the results
for r in results: for r in results:
@ -598,7 +603,7 @@ def _get_eval_frame_data(frame_data):
return frame_data_for_eval return frame_data_for_eval
def run_eval(cfg, model, all_source_cameras, loader, task, device): def _run_eval(model, all_source_cameras, loader, task: Task, device):
""" """
Run the evaluation loop on the test dataloader Run the evaluation loop on the test dataloader
""" """
@ -672,8 +677,7 @@ def _seed_all_random_engines(seed: int):
class ExperimentConfig: class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel) generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer) solver_args: DictConfig = get_default_args_field(init_optimizer)
dataset_args: DictConfig = get_default_args_field(dataset_zoo) data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic" architecture: str = "generic"
detect_anomaly: bool = False detect_anomaly: bool = False
eval_only: bool = False eval_only: bool = False

View File

@ -23,6 +23,7 @@ import torch
import torch.nn.functional as Fu import torch.nn.functional as Fu
from experiment import init_model from experiment import init_model
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import ( from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData, FrameData,
@ -326,12 +327,14 @@ def export_scenes(
config.gpu_idx = gpu_idx config.gpu_idx = gpu_idx
config.exp_dir = exp_dir config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full # important so that the CO3D dataset gets loaded in full
config.dataset_args.test_on_train = False config.data_source_args.dataset_args.test_on_train = False
# Set the rendering image size # Set the rendering image size
config.generic_model_args.render_image_width = render_size[0] config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1] config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None: if restrict_sequence_name is not None:
config.dataset_args.restrict_sequence_name = restrict_sequence_name config.data_source_args.dataset_args.restrict_sequence_name = (
restrict_sequence_name
)
# Set up the CUDA env for the visualization # Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@ -343,7 +346,8 @@ def export_scenes(
model.eval() model.eval()
# Setup the dataset # Setup the dataset
datasets = dataset_zoo(**config.dataset_args) datasource = ImplicitronDataSource(**config.data_source_args)
datasets = dataset_zoo(**datasource.dataset_args)
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None) dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
if dataset is None: if dataset is None:
raise ValueError(f"{split} dataset not provided") raise ValueError(f"{split} dataset not provided")

View File

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase
from .dataloader_zoo import dataloader_zoo, Dataloaders
from .dataset_zoo import dataset_zoo, Datasets
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
class DataSourceBase(ReplaceableBase):
"""
Base class for a data source in Implicitron. It encapsulates Dataset
and DataLoader configuration.
"""
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
raise NotImplementedError()
class ImplicitronDataSource(DataSourceBase):
"""
Represents the data used in Implicitron. This is the only implementation
of DataSourceBase provided.
"""
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
datasets = dataset_zoo(**self.dataset_args)
dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
return datasets, dataloaders
def get_task(self) -> Task:
eval_task = self.dataset_args["dataset_name"].split("_")[-1]
return Task(eval_task)

View File

@ -7,11 +7,12 @@
import dataclasses import dataclasses
import os import os
from typing import cast, Optional, Tuple from typing import Any, cast, Dict, List, Optional, Tuple
import lpips import lpips
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import ( from pytorch3d.implicitron.dataset.implicitron_dataset import (
@ -47,10 +48,12 @@ def main() -> None:
""" """
task_results = {} task_results = {}
for task in ("singlesequence", "multisequence"): for task in (Task.SINGLE_SEQUENCE, Task.MULTI_SEQUENCE):
task_results[task] = [] task_results[task] = []
for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]: for category in CO3D_CATEGORIES[: (20 if task == Task.SINGLE_SEQUENCE else 10)]:
for single_sequence_id in (0, 1) if task == "singlesequence" else (None,): for single_sequence_id in (
(0, 1) if task == Task.SINGLE_SEQUENCE else (None,)
):
category_result = evaluate_dbir_for_category( category_result = evaluate_dbir_for_category(
category, task=task, single_sequence_id=single_sequence_id category, task=task, single_sequence_id=single_sequence_id
) )
@ -74,9 +77,9 @@ def main() -> None:
def evaluate_dbir_for_category( def evaluate_dbir_for_category(
category: str = "apple", category: str,
task: Task,
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0), bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
task: str = "singlesequence",
single_sequence_id: Optional[int] = None, single_sequence_id: Optional[int] = None,
num_workers: int = 16, num_workers: int = 16,
path_manager: Optional[PathManager] = None, path_manager: Optional[PathManager] = None,
@ -101,14 +104,16 @@ def evaluate_dbir_for_category(
torch.manual_seed(42) torch.manual_seed(42)
if task not in ["multisequence", "singlesequence"]: dataset_name = {
raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'") Task.SINGLE_SEQUENCE: "co3d_singlesequence",
Task.MULTI_SEQUENCE: "co3d_multisequence",
}[task]
datasets = dataset_zoo( datasets = dataset_zoo(
category=category, category=category,
dataset_root=os.environ["CO3D_DATASET_ROOT"], dataset_root=os.environ["CO3D_DATASET_ROOT"],
assert_single_seq=task == "singlesequence", assert_single_seq=task == Task.SINGLE_SEQUENCE,
dataset_name=f"co3d_{task}", dataset_name=dataset_name,
test_on_train=False, test_on_train=False,
load_point_clouds=True, load_point_clouds=True,
test_restrict_sequence_id=single_sequence_id, test_restrict_sequence_id=single_sequence_id,
@ -122,7 +127,7 @@ def evaluate_dbir_for_category(
if test_dataset is None or test_dataloader is None: if test_dataset is None or test_dataloader is None:
raise ValueError("must have a test dataset.") raise ValueError("must have a test dataset.")
if task == "singlesequence": if task == Task.SINGLE_SEQUENCE:
# all_source_cameras are needed for evaluation of the # all_source_cameras are needed for evaluation of the
# target camera difficulty # target camera difficulty
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`. # pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
@ -173,7 +178,9 @@ def evaluate_dbir_for_category(
return category_result["results"] return category_result["results"]
def _print_aggregate_results(task, task_results) -> None: def _print_aggregate_results(
task: Task, task_results: Dict[Task, List[List[Dict[str, Any]]]]
) -> None:
""" """
Prints the aggregate metrics for a given task. Prints the aggregate metrics for a given task.
""" """

View File

@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender from pytorch3d.implicitron.models.base_model import ImplicitronRender
@ -317,7 +318,7 @@ def eval_batch(
if visualize: if visualize:
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now) visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
if break_after_visualising: if break_after_visualising:
import pdb import pdb # noqa: B602
pdb.set_trace() pdb.set_trace()
@ -411,16 +412,16 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
return ious.topk(k=min(topk, len(ious) - 1)).values.mean() return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
def get_camera_difficulty_bin_edges(task: str): def _get_camera_difficulty_bin_edges(task: Task):
""" """
Get the edges of camera difficulty bins. Get the edges of camera difficulty bins.
""" """
_eps = 1e-5 _eps = 1e-5
if task == "multisequence": if task == Task.MULTI_SEQUENCE:
# TODO: extract those to constants # TODO: extract those to constants
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4) diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
diff_bin_edges[0] = 0.0 - _eps diff_bin_edges[0] = 0.0 - _eps
elif task == "singlesequence": elif task == Task.SINGLE_SEQUENCE:
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float() diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
else: else:
raise ValueError(f"No such eval task {task}.") raise ValueError(f"No such eval task {task}.")
@ -430,7 +431,7 @@ def get_camera_difficulty_bin_edges(task: str):
def summarize_nvs_eval_results( def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]], per_batch_eval_results: List[Dict[str, Any]],
task: str = "singlesequence", task: Task,
): ):
""" """
Compile the per-batch evaluation results `per_batch_eval_results` into Compile the per-batch evaluation results `per_batch_eval_results` into
@ -439,7 +440,6 @@ def summarize_nvs_eval_results(
Args: Args:
per_batch_eval_results: Metrics of each per-batch evaluation. per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task. task: The type of the new-view synthesis task.
Either 'singlesequence' or 'multisequence'.
Returns: Returns:
nvs_results_flat: A flattened dict of all aggregate metrics. nvs_results_flat: A flattened dict of all aggregate metrics.
@ -447,10 +447,10 @@ def summarize_nvs_eval_results(
""" """
n_batches = len(per_batch_eval_results) n_batches = len(per_batch_eval_results)
eval_sets: List[Optional[str]] = [] eval_sets: List[Optional[str]] = []
if task == "singlesequence": if task == Task.SINGLE_SEQUENCE:
eval_sets = [None] eval_sets = [None]
# assert n_batches==100 # assert n_batches==100
elif task == "multisequence": elif task == Task.MULTI_SEQUENCE:
eval_sets = ["train", "test"] eval_sets = ["train", "test"]
# assert n_batches==1000 # assert n_batches==1000
else: else:
@ -466,17 +466,17 @@ def summarize_nvs_eval_results(
# init the result database dict # init the result database dict
results = [] results = []
diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task) diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
n_diff_edges = diff_bin_edges.numel() n_diff_edges = diff_bin_edges.numel()
# add per set averages # add per set averages
for SET in eval_sets: for SET in eval_sets:
if SET is None: if SET is None:
# task=='singlesequence' assert task == Task.SINGLE_SEQUENCE
ok_set = torch.ones(n_batches, dtype=torch.bool) ok_set = torch.ones(n_batches, dtype=torch.bool)
set_name = "test" set_name = "test"
else: else:
# task=='multisequence' assert task == Task.MULTI_SEQUENCE
ok_set = is_train == int(SET == "train") ok_set = is_train == int(SET == "train")
set_name = SET set_name = SET
@ -501,7 +501,7 @@ def summarize_nvs_eval_results(
} }
) )
if task == "multisequence": if task == Task.MULTI_SEQUENCE:
# split based on n_src_views # split based on n_src_views
n_src_views = batch_sizes - 1 n_src_views = batch_sizes - 1
for n_src in EVAL_N_SRC_VIEWS: for n_src in EVAL_N_SRC_VIEWS: