return types for dataset_zoo, dataloader_zoo

Summary: Stronger typing for these functions

Reviewed By: shapovalov

Differential Revision: D36170489

fbshipit-source-id: a2104b29dbbbcfcf91ae1d076cd6b0e3d2030c0b
This commit is contained in:
Jeremy Reizenstein
2022-05-13 05:38:14 -07:00
committed by Facebook GitHub Bot
parent 90ab219d88
commit 2c1901522a
5 changed files with 176 additions and 117 deletions

View File

@@ -64,8 +64,8 @@ import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
@@ -453,7 +453,6 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
# setup datasets
datasets = dataset_zoo(**cfg.dataset_args)
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
# init the model
@@ -499,7 +498,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
model,
stats,
epoch,
dataloaders["train"],
dataloaders.train,
optimizer,
False,
visdom_env_root=vis_utils.get_visdom_env(cfg),
@@ -508,12 +507,12 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
)
# val loop (optional)
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
trainvalidate(
model,
stats,
epoch,
dataloaders["val"],
dataloaders.val,
optimizer,
True,
visdom_env_root=vis_utils.get_visdom_env(cfg),
@@ -523,11 +522,11 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
# eval loop (optional)
if (
"test" in dataloaders
dataloaders.test is not None
and cfg.test_interval > 0
and epoch % cfg.test_interval == 0
):
run_eval(cfg, model, stats, dataloaders["test"], device=device)
run_eval(cfg, model, stats, dataloaders.test, device=device)
assert stats.epoch == epoch, "inconsistent stats!"
@@ -550,23 +549,28 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
def _eval_and_dump(
cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
) -> None:
"""
Run the evaluation loop with the test data loader and
save the predictions to the `exp_dir`.
"""
if "test" not in dataloaders:
dataloader = dataloaders.test
if dataloader is None:
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
all_source_cameras = (
_get_all_source_cameras(datasets["train"])
if eval_task == "singlesequence"
else None
)
if eval_task == "singlesequence":
if datasets.train is None:
raise ValueError("train dataset must be provided")
all_source_cameras = _get_all_source_cameras(datasets.train)
else:
all_source_cameras = None
results = run_eval(
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
cfg, model, all_source_cameras, dataloader, eval_task, device=device
)
# add the evaluation epoch to the results

View File

@@ -27,6 +27,7 @@ from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
ImplicitronDatasetBase,
)
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
@@ -342,7 +343,10 @@ def export_scenes(
model.eval()
# Setup the dataset
dataset = dataset_zoo(**config.dataset_args)[split]
datasets = dataset_zoo(**config.dataset_args)
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
if dataset is None:
raise ValueError(f"{split} dataset not provided")
# iterate over the sequences in the dataset
for sequence_name in dataset.sequence_names():