data_source

Summary:
Move dataset_args and dataloader_args out of ExperimentConfig into a new data-source member (exposed as data_source_args) so that it can contain replaceable components.

Also add an enum, Task, for the task type.
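
As used by the trainer in the diff below, the new member exposes roughly the following surface. This is a hedged sketch, not the actual implementation: the MULTI_SEQUENCE value and the dataset_name-to-Task mapping are assumptions inferred from the string check this commit removes.

# Hedged sketch of the assumed interface, inferred from the calls in the diff.
from enum import Enum


class Task(Enum):
    SINGLE_SEQUENCE = "singlesequence"
    MULTI_SEQUENCE = "multisequence"  # assumed second value


class ImplicitronDataSourceSketch:
    """Assumed interface; field names mirror the moved config members."""

    def __init__(self, dataset_args: dict, dataloader_args: dict) -> None:
        self.dataset_args = dataset_args
        self.dataloader_args = dataloader_args

    def get_datasets_and_dataloaders(self):
        # Presumably wraps the dataset_zoo/dataloader_zoo calls that
        # run_training used to make directly (see the hunks below).
        raise NotImplementedError

    def get_task(self) -> Task:
        # Assumed mapping: derive the task from the dataset-name suffix,
        # mirroring the string check removed from _eval_and_dump below.
        return Task(self.dataset_args["dataset_name"].split("_")[-1])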

Reviewed By: shapovalov

Differential Revision: D36201719

fbshipit-source-id: 47d6967bfea3b7b146b6bbd1572e0457c9365871
Author: Jeremy Reizenstein
Date: 2022-05-20 07:50:30 -07:00
Committed by: Facebook GitHub Bot
Parent: 9ec9d057cc
Commit: 73dc109dba

10 changed files with 194 additions and 124 deletions


@@ -64,8 +64,9 @@ import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
-from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
-from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
+from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
+from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
from pytorch3d.implicitron.dataset.implicitron_dataset import (
    FrameData,
    ImplicitronDataset,
@@ -428,7 +429,7 @@ def trainvalidate(
            optimizer.step()


-def run_training(cfg: DictConfig, device: str = "cpu"):
+def run_training(cfg: DictConfig, device: str = "cpu") -> None:
    """
    Entry point to run the training and validation loops
    based on the specified config file.
@@ -452,8 +453,9 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
        warnings.warn("Cant dump config due to insufficient permissions!")

    # setup datasets
-    datasets = dataset_zoo(**cfg.dataset_args)
-    dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
+    datasource = ImplicitronDataSource(**cfg.data_source_args)
+    datasets, dataloaders = datasource.get_datasets_and_dataloaders()
+    task = datasource.get_task()

    # init the model
    model, stats, optimizer_state = init_model(cfg)
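
A plausible reading of what get_datasets_and_dataloaders() folds in, given the two calls it replaces above; a hedged sketch, not the real method body of ImplicitronDataSource.

# Hedged sketch: the removed dataset_zoo/dataloader_zoo calls, moved into the data source.
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo


def _get_datasets_and_dataloaders_sketch(dataset_args: dict, dataloader_args: dict):
    # Build the datasets, then wrap them in dataloaders, as run_training did before.
    datasets = dataset_zoo(**dataset_args)
    dataloaders = dataloader_zoo(datasets, **dataloader_args)
    return datasets, dataloaders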
@@ -464,7 +466,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
    # only run evaluation on the test dataloader
    if cfg.eval_only:
-        _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
        return

    # init the optimizer
@@ -526,7 +528,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
            and cfg.test_interval > 0
            and epoch % cfg.test_interval == 0
        ):
-            run_eval(cfg, model, stats, dataloaders.test, device=device)
+            _run_eval(model, stats, dataloaders.test, task, device=device)

        assert stats.epoch == epoch, "inconsistent stats!"
@@ -546,11 +548,17 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
            logger.info(f"LR change! {cur_lr} -> {new_lr}")

    if cfg.test_when_finished:
-        _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)


def _eval_and_dump(
-    cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
+    cfg,
+    task: Task,
+    datasets: Datasets,
+    dataloaders: Dataloaders,
+    model,
+    stats,
+    device,
) -> None:
    """
    Run the evaluation loop with the test data loader and
@@ -562,16 +570,13 @@ def _eval_and_dump(
    if dataloader is None:
        raise ValueError('Dataloaders have to contain the "test" entry for eval!')

-    eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
-    if eval_task == "singlesequence":
+    if task == Task.SINGLE_SEQUENCE:
        if datasets.train is None:
            raise ValueError("train dataset must be provided")
        all_source_cameras = _get_all_source_cameras(datasets.train)
    else:
        all_source_cameras = None

-    results = run_eval(
-        cfg, model, all_source_cameras, dataloader, eval_task, device=device
-    )
+    results = _run_eval(model, all_source_cameras, dataloader, task, device=device)

    # add the evaluation epoch to the results
    for r in results:
@@ -598,7 +603,7 @@ def _get_eval_frame_data(frame_data):
    return frame_data_for_eval


-def run_eval(cfg, model, all_source_cameras, loader, task, device):
+def _run_eval(model, all_source_cameras, loader, task: Task, device):
    """
    Run the evaluation loop on the test dataloader
    """
@@ -672,8 +677,7 @@ def _seed_all_random_engines(seed: int):
class ExperimentConfig:
    generic_model_args: DictConfig = get_default_args_field(GenericModel)
    solver_args: DictConfig = get_default_args_field(init_optimizer)
-    dataset_args: DictConfig = get_default_args_field(dataset_zoo)
-    dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
+    data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
    architecture: str = "generic"
    detect_anomaly: bool = False
    eval_only: bool = False
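
For configs that override these fields, the dataset and dataloader settings now nest under the new member. A hedged before/after sketch follows; the nested key names are assumptions based on the commit summary, and the values are illustrative only.

# Hedged sketch: an existing override moved under data_source_args.
# Key names and values are illustrative, not taken from a real config.
from omegaconf import OmegaConf

old_style = OmegaConf.create(
    {
        "dataset_args": {"dataset_name": "co3d_singlesequence"},
        "dataloader_args": {"batch_size": 10},
    }
)

new_style = OmegaConf.create(
    {
        "data_source_args": {
            "dataset_args": {"dataset_name": "co3d_singlesequence"},
            "dataloader_args": {"batch_size": 10},
        }
    }
)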