return types for dataset_zoo, dataloader_zoo

Summary: Stronger typing for these functions

Reviewed By: shapovalov

Differential Revision: D36170489

fbshipit-source-id: a2104b29dbbbcfcf91ae1d076cd6b0e3d2030c0b
Jeremy Reizenstein 2022-05-13 05:38:14 -07:00 committed by Facebook GitHub Bot
parent 90ab219d88
commit 2c1901522a
5 changed files with 176 additions and 117 deletions
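
In practice, the change replaces dict lookups with attribute access and explicit None checks at every call site. A minimal before/after sketch (abbreviated from the diffs below; `run_validation` is a hypothetical stand-in for the trainer's `trainvalidate` call):

    # Before: both zoos returned plain dicts keyed by "train"/"val"/"test".
    datasets = dataset_zoo(**cfg.dataset_args)
    dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
    train_loader = dataloaders["train"]  # raises KeyError if the subset is missing

    # After: they return dataclasses with Optional fields, so a missing
    # subset is an explicit, type-checkable None.
    datasets = dataset_zoo(**cfg.dataset_args)  # -> Datasets
    dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)  # -> Dataloaders
    if dataloaders.val is not None:
        run_validation(dataloaders.val)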

View File

@@ -64,8 +64,8 @@ import tqdm
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
-from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
-from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
+from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
+from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
@@ -453,7 +453,6 @@ def run_training(cfg: DictConfig, device: str = "cpu"):

     # setup datasets
     datasets = dataset_zoo(**cfg.dataset_args)
-    cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
     dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)

     # init the model
@@ -499,7 +498,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
             model,
             stats,
             epoch,
-            dataloaders["train"],
+            dataloaders.train,
             optimizer,
             False,
             visdom_env_root=vis_utils.get_visdom_env(cfg),
@@ -508,12 +507,12 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
         )

         # val loop (optional)
-        if "val" in dataloaders and epoch % cfg.validation_interval == 0:
+        if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
             trainvalidate(
                 model,
                 stats,
                 epoch,
-                dataloaders["val"],
+                dataloaders.val,
                 optimizer,
                 True,
                 visdom_env_root=vis_utils.get_visdom_env(cfg),
@@ -523,11 +522,11 @@ def run_training(cfg: DictConfig, device: str = "cpu"):

         # eval loop (optional)
         if (
-            "test" in dataloaders
+            dataloaders.test is not None
             and cfg.test_interval > 0
             and epoch % cfg.test_interval == 0
         ):
-            run_eval(cfg, model, stats, dataloaders["test"], device=device)
+            run_eval(cfg, model, stats, dataloaders.test, device=device)

         assert stats.epoch == epoch, "inconsistent stats!"
@@ -550,23 +549,28 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
     _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)


-def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
+def _eval_and_dump(
+    cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
+) -> None:
     """
     Run the evaluation loop with the test data loader and
     save the predictions to the `exp_dir`.
     """
-    if "test" not in dataloaders:
+    dataloader = dataloaders.test
+    if dataloader is None:
         raise ValueError('Dataloaders have to contain the "test" entry for eval!')

     eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]

-    all_source_cameras = (
-        _get_all_source_cameras(datasets["train"])
-        if eval_task == "singlesequence"
-        else None
-    )
+    if eval_task == "singlesequence":
+        if datasets.train is None:
+            raise ValueError("train dataset must be provided")
+        all_source_cameras = _get_all_source_cameras(datasets.train)
+    else:
+        all_source_cameras = None

     results = run_eval(
-        cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
+        cfg, model, all_source_cameras, dataloader, eval_task, device=device
     )

     # add the evaluation epoch to the results

View File

@@ -27,6 +27,7 @@ from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
+    ImplicitronDatasetBase,
 )
 from pytorch3d.implicitron.dataset.utils import is_train_frame
 from pytorch3d.implicitron.models.base_model import EvaluationMode
@@ -342,7 +343,10 @@ def export_scenes(
     model.eval()

     # Setup the dataset
-    dataset = dataset_zoo(**config.dataset_args)[split]
+    datasets = dataset_zoo(**config.dataset_args)
+    dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
+    if dataset is None:
+        raise ValueError(f"{split} dataset not provided")

     # iterate over the sequences in the dataset
     for sequence_name in dataset.sequence_names():
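
Since `Datasets` is a dataclass rather than a dict, selecting a subset by name goes through `getattr`, as in the hunk above; a tiny sketch of the pattern (`split` is one of "train"/"val"/"test"):

    split = "test"
    dataset = getattr(datasets, split, None)  # None if the subset was not built
    if dataset is None:
        raise ValueError(f"{split} dataset not provided")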

View File

@@ -4,18 +4,36 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict, Sequence
+from dataclasses import dataclass
+from typing import Optional, Sequence

 import torch
 from pytorch3d.implicitron.tools.config import enable_get_default_args

+from .dataset_zoo import Datasets
 from .implicitron_dataset import FrameData, ImplicitronDatasetBase
 from .scene_batch_sampler import SceneBatchSampler


+@dataclass
+class Dataloaders:
+    """
+    A provider of dataloaders for implicitron.
+
+    Members:
+        train: a dataloader for training
+        val: a dataloader for validating during training
+        test: a dataloader for final evaluation
+    """
+
+    train: Optional[torch.utils.data.DataLoader[FrameData]]
+    val: Optional[torch.utils.data.DataLoader[FrameData]]
+    test: Optional[torch.utils.data.DataLoader[FrameData]]
+
+
 def dataloader_zoo(
-    datasets: Dict[str, ImplicitronDatasetBase],
-    dataset_name: str = "co3d_singlesequence",
+    datasets: Datasets,
     batch_size: int = 1,
     num_workers: int = 0,
     dataset_len: int = 1000,
@@ -24,7 +42,7 @@ def dataloader_zoo(
     sample_consecutive_frames: bool = False,
     consecutive_frames_max_gap: int = 0,
     consecutive_frames_max_gap_seconds: float = 0.1,
-) -> Dict[str, torch.utils.data.DataLoader]:
+) -> Dataloaders:
     """
     Returns a set of dataloaders for a given set of datasets.
@@ -57,44 +75,43 @@ def dataloader_zoo(
         dataloaders: A dictionary containing the
             `"dataset_subset_name": torch_dataloader_object` key, value pairs.
     """
-    if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
-        raise ValueError(f"Unsupported dataset: {dataset_name}")
-
-    dataloaders = {}
+    dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate}

-    if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
-        for dataset_set, dataset in datasets.items():
-            num_samples = {
-                "train": dataset_len,
-                "val": dataset_len_val,
-                "test": None,
-            }[dataset_set]
+    def train_or_val_loader(
+        dataset: Optional[ImplicitronDatasetBase], num_batches: int
+    ) -> Optional[torch.utils.data.DataLoader]:
+        if dataset is None:
+            return None
+        batch_sampler = SceneBatchSampler(
+            dataset,
+            batch_size,
+            num_batches=len(dataset) if num_batches <= 0 else num_batches,
+            images_per_seq_options=images_per_seq_options,
+            sample_consecutive_frames=sample_consecutive_frames,
+            consecutive_frames_max_gap=consecutive_frames_max_gap,
+            consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds,
+        )
+        return torch.utils.data.DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+            **dataloader_kwargs,
+        )

-            if dataset_set == "test":
-                batch_sampler = dataset.get_eval_batches()
-            else:
-                assert num_samples is not None
-                num_samples = len(dataset) if num_samples <= 0 else num_samples
-                batch_sampler = SceneBatchSampler(
-                    dataset,
-                    batch_size,
-                    num_batches=num_samples,
-                    images_per_seq_options=images_per_seq_options,
-                    sample_consecutive_frames=sample_consecutive_frames,
-                    consecutive_frames_max_gap=consecutive_frames_max_gap,
-                )
-
-            dataloaders[dataset_set] = torch.utils.data.DataLoader(
-                dataset,
-                num_workers=num_workers,
-                batch_sampler=batch_sampler,
-                collate_fn=FrameData.collate,
-            )
+    train_dataloader = train_or_val_loader(datasets.train, dataset_len)
+    val_dataloader = train_or_val_loader(datasets.val, dataset_len_val)
+
+    test_dataset = datasets.test
+    if test_dataset is not None:
+        test_dataloader = torch.utils.data.DataLoader(
+            test_dataset,
+            batch_sampler=test_dataset.get_eval_batches(),
+            **dataloader_kwargs,
+        )
     else:
-        raise ValueError(f"Unsupported dataset: {dataset_name}")
+        test_dataloader = None

-    return dataloaders
+    return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader)


 enable_get_default_args(dataloader_zoo)
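
A short usage sketch of the new return type (the dataset name and category are illustrative values, not part of this commit):

    from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
    from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo

    datasets = dataset_zoo(dataset_name="co3d_singlesequence", category="teddybear")
    dataloaders = dataloader_zoo(datasets, batch_size=2)
    if dataloaders.train is not None:  # None when the train subset was not built
        for frame_data in dataloaders.train:  # batches collated by FrameData.collate
            ...

Note the design choice visible above: train/val loaders are built by the shared `train_or_val_loader` helper around a `SceneBatchSampler`, while the test loader is driven by the dataset's own `get_eval_batches()`.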

View File

@@ -5,10 +5,10 @@
 # LICENSE file in the root directory of this source tree.

-import copy
 import json
 import os
-from typing import Any, Dict, List, Optional, Sequence
+from dataclasses import dataclass
+from typing import Any, Dict, Iterator, List, Optional, Sequence

 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import enable_get_default_args
@@ -52,6 +52,34 @@ CO3D_CATEGORIES: List[str] = list(reversed([
 _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")


+@dataclass
+class Datasets:
+    """
+    A provider of datasets for implicitron.
+
+    Members:
+        train: a dataset for training
+        val: a dataset for validating during training
+        test: a dataset for final evaluation
+    """
+
+    train: Optional[ImplicitronDatasetBase]
+    val: Optional[ImplicitronDatasetBase]
+    test: Optional[ImplicitronDatasetBase]
+
+    def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
+        """
+        Iterator over all datasets.
+        """
+        if self.train is not None:
+            yield self.train
+        if self.val is not None:
+            yield self.val
+        if self.test is not None:
+            yield self.test
+
+
 def dataset_zoo(
     dataset_name: str = "co3d_singlesequence",
     dataset_root: str = _CO3D_DATASET_ROOT,
@@ -69,7 +97,7 @@ def dataset_zoo(
     only_test_set: bool = False,
     aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
     path_manager: Optional[PathManager] = None,
-) -> Dict[str, ImplicitronDatasetBase]:
+) -> Datasets:
     """
     Generates the training / validation and testing dataset objects.
@@ -101,12 +129,31 @@ def dataset_zoo(
         datasets: A dictionary containing the
             `"dataset_subset_name": torch_dataset_object` key, value pairs.
     """
-    datasets = {}
+    if only_test_set and test_on_train:
+        raise ValueError("Cannot have only_test_set and test_on_train")

     # TODO:
     # - implement loading multiple categories
     if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
+        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
+        sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
+        subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
+        common_kwargs = {
+            "dataset_root": dataset_root,
+            "limit_to": limit_to,
+            "limit_sequences_to": limit_sequences_to,
+            "load_point_clouds": load_point_clouds,
+            "mask_images": mask_images,
+            "mask_depths": mask_depths,
+            "pick_sequence": restrict_sequence_name,
+            "path_manager": path_manager,
+            "frame_annotations_file": frame_file,
+            "sequence_annotations_file": sequence_file,
+            "subset_lists_file": subset_lists_file,
+            **aux_dataset_kwargs,
+        }
+
         # This maps the common names of the dataset subsets ("train"/"val"/"test")
         # to the names of the subsets in the CO3D dataset.
         set_names_mapping = _get_co3d_set_names_mapping(
@@ -156,65 +203,48 @@ def dataset_zoo(
             # overwrite the restrict_sequence_name
             restrict_sequence_name = [eval_sequence_name]

-        for dataset, subsets in set_names_mapping.items():
-            frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
-            sequence_file = os.path.join(
-                dataset_root, category, "sequence_annotations.jgz"
+        train_dataset = None
+        if not only_test_set:
+            train_dataset = ImplicitronDataset(
+                n_frames_per_sequence=n_frames_per_sequence,
+                subsets=set_names_mapping["train"],
+                **common_kwargs,
             )
-            subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
-
-            # TODO: maybe directly in param list
-            params = {
-                **copy.deepcopy(aux_dataset_kwargs),
-                "frame_annotations_file": frame_file,
-                "sequence_annotations_file": sequence_file,
-                "subset_lists_file": subset_lists_file,
-                "dataset_root": dataset_root,
-                "limit_to": limit_to,
-                "limit_sequences_to": limit_sequences_to,
-                "n_frames_per_sequence": n_frames_per_sequence
-                if dataset == "train"
-                else -1,
-                "subsets": subsets,
-                "load_point_clouds": load_point_clouds,
-                "mask_images": mask_images,
-                "mask_depths": mask_depths,
-                "pick_sequence": restrict_sequence_name,
-                "path_manager": path_manager,
-            }
-
-            datasets[dataset] = ImplicitronDataset(**params)
-            if dataset == "test":
-                if len(restrict_sequence_name) > 0:
-                    eval_batch_index = [
-                        b for b in eval_batch_index if b[0][0] in restrict_sequence_name
-                    ]
-
-                datasets[dataset].eval_batches = datasets[
-                    dataset
-                ].seq_frame_index_to_dataset_index(eval_batch_index)
-
-        if assert_single_seq:
-            # check theres only one sequence in all datasets
-            assert (
-                len(
-                    {
-                        e["frame_annotation"].sequence_name
-                        for dset in datasets.values()
-                        for e in dset.frame_annots
-                    }
-                )
-                <= 1
-            ), "Multiple sequences loaded but expected one"
+        if test_on_train:
+            assert train_dataset is not None
+            val_dataset = test_dataset = train_dataset
+        else:
+            val_dataset = ImplicitronDataset(
+                n_frames_per_sequence=1,
+                subsets=set_names_mapping["val"],
+                **common_kwargs,
+            )
+            test_dataset = ImplicitronDataset(
+                n_frames_per_sequence=1,
+                subsets=set_names_mapping["test"],
+                **common_kwargs,
+            )
+            if len(restrict_sequence_name) > 0:
+                eval_batch_index = [
+                    b for b in eval_batch_index if b[0][0] in restrict_sequence_name
+                ]
+            test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
+                eval_batch_index
+            )
+        datasets = Datasets(train=train_dataset, val=val_dataset, test=test_dataset)
     else:
         raise ValueError(f"Unsupported dataset: {dataset_name}")

-    if test_on_train:
-        datasets["val"] = datasets["train"]
-        datasets["test"] = datasets["train"]
+    if assert_single_seq:
+        # check there's only one sequence in all datasets
+        sequence_names = {
+            sequence_name
+            for dset in datasets.iter_datasets()
+            for sequence_name in dset.sequence_names()
+        }
+        if len(sequence_names) > 1:
+            raise ValueError("Multiple sequences loaded but expected one")

     return datasets
@@ -231,6 +261,11 @@ def _get_co3d_set_names_mapping(
     Returns the mapping of the common dataset subset names ("train"/"val"/"test")
     to the names of the corresponding subsets in the CO3D dataset
     ("test_known"/"test_unseen"/"train_known"/"train_unseen").
+
+    The keys returned will be
+    - train (if not only_test)
+    - val (if not test_on_train)
+    - test (if not test_on_train)
     """
     single_seq = dataset_name == "co3d_singlesequence"
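
A minimal sketch of consuming the new `Datasets` dataclass (`train_dataset` and `test_dataset` are placeholders for datasets built elsewhere):

    datasets = Datasets(train=train_dataset, val=None, test=test_dataset)

    # iter_datasets() skips subsets that are None, so this visits only
    # the train and test datasets.
    for dset in datasets.iter_datasets():
        print(len(dset))

    # Attribute access replaces the old dict lookup:
    if datasets.test is not None:
        eval_batches = datasets.test.get_eval_batches()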

View File

@@ -115,13 +115,12 @@ def evaluate_dbir_for_category(
         path_manager=path_manager,
     )

-    dataloaders = dataloader_zoo(
-        datasets,
-        dataset_name=f"co3d_{task}",
-    )
+    dataloaders = dataloader_zoo(datasets)

-    test_dataset = datasets["test"]
-    test_dataloader = dataloaders["test"]
+    test_dataset = datasets.test
+    test_dataloader = dataloaders.test
+    if test_dataset is None or test_dataloader is None:
+        raise ValueError("must have a test dataset.")

     if task == "singlesequence":
         # all_source_cameras are needed for evaluation of the