mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 19:36:23 +08:00
return types for dataset_zoo, dataloader_zoo
Summary: Stronger typing for these functions Reviewed By: shapovalov Differential Revision: D36170489 fbshipit-source-id: a2104b29dbbbcfcf91ae1d076cd6b0e3d2030c0b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
90ab219d88
commit
2c1901522a
@@ -64,8 +64,8 @@ import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from packaging import version
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
@@ -453,7 +453,6 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
|
||||
# setup datasets
|
||||
datasets = dataset_zoo(**cfg.dataset_args)
|
||||
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
|
||||
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
|
||||
|
||||
# init the model
|
||||
@@ -499,7 +498,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders["train"],
|
||||
dataloaders.train,
|
||||
optimizer,
|
||||
False,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
@@ -508,12 +507,12 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
)
|
||||
|
||||
# val loop (optional)
|
||||
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
|
||||
if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
|
||||
trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders["val"],
|
||||
dataloaders.val,
|
||||
optimizer,
|
||||
True,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
@@ -523,11 +522,11 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
|
||||
# eval loop (optional)
|
||||
if (
|
||||
"test" in dataloaders
|
||||
dataloaders.test is not None
|
||||
and cfg.test_interval > 0
|
||||
and epoch % cfg.test_interval == 0
|
||||
):
|
||||
run_eval(cfg, model, stats, dataloaders["test"], device=device)
|
||||
run_eval(cfg, model, stats, dataloaders.test, device=device)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
|
||||
@@ -550,23 +549,28 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
|
||||
|
||||
|
||||
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
|
||||
def _eval_and_dump(
|
||||
cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
|
||||
) -> None:
|
||||
"""
|
||||
Run the evaluation loop with the test data loader and
|
||||
save the predictions to the `exp_dir`.
|
||||
"""
|
||||
|
||||
if "test" not in dataloaders:
|
||||
dataloader = dataloaders.test
|
||||
|
||||
if dataloader is None:
|
||||
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
|
||||
|
||||
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
|
||||
all_source_cameras = (
|
||||
_get_all_source_cameras(datasets["train"])
|
||||
if eval_task == "singlesequence"
|
||||
else None
|
||||
)
|
||||
if eval_task == "singlesequence":
|
||||
if datasets.train is None:
|
||||
raise ValueError("train dataset must be provided")
|
||||
all_source_cameras = _get_all_source_cameras(datasets.train)
|
||||
else:
|
||||
all_source_cameras = None
|
||||
results = run_eval(
|
||||
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
|
||||
cfg, model, all_source_cameras, dataloader, eval_task, device=device
|
||||
)
|
||||
|
||||
# add the evaluation epoch to the results
|
||||
|
||||
@@ -27,6 +27,7 @@ from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
ImplicitronDatasetBase,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
@@ -342,7 +343,10 @@ def export_scenes(
|
||||
model.eval()
|
||||
|
||||
# Setup the dataset
|
||||
dataset = dataset_zoo(**config.dataset_args)[split]
|
||||
datasets = dataset_zoo(**config.dataset_args)
|
||||
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
|
||||
if dataset is None:
|
||||
raise ValueError(f"{split} dataset not provided")
|
||||
|
||||
# iterate over the sequences in the dataset
|
||||
for sequence_name in dataset.sequence_names():
|
||||
|
||||
Reference in New Issue
Block a user