get_all_train_cameras

Summary: As part of removing Task, make the dataset code generate the source cameras for itself. There's a small optimization available here, in that the JsonIndexDataset could avoid loading images.

Reviewed By: shapovalov

Differential Revision: D37313423

fbshipit-source-id: 3e5e0b2aabbf9cc51f10547a3523e98c72ad8755
This commit is contained in:
Jeremy Reizenstein
2022-07-06 07:13:41 -07:00
committed by Facebook GitHub Bot
parent 771cf8a328
commit 4e87c2b7f1
12 changed files with 139 additions and 94 deletions

View File

@@ -66,9 +66,7 @@ from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
@@ -456,6 +454,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
datasource = ImplicitronDataSource(**cfg.data_source_args)
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
task = datasource.get_task()
all_train_cameras = datasource.get_all_train_cameras()
# init the model
model, stats, optimizer_state = init_model(cfg)
@@ -466,7 +465,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
# only run evaluation on the test dataloader
if cfg.eval_only:
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
_eval_and_dump(
cfg,
task,
all_train_cameras,
datasets,
dataloaders,
model,
stats,
device=device,
)
return
# init the optimizer
@@ -528,7 +536,9 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
and cfg.test_interval > 0
and epoch % cfg.test_interval == 0
):
_run_eval(model, stats, dataloaders.test, task, device=device)
_run_eval(
model, all_train_cameras, dataloaders.test, task, device=device
)
assert stats.epoch == epoch, "inconsistent stats!"
@@ -548,12 +558,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if cfg.test_when_finished:
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
_eval_and_dump(
cfg,
task,
all_train_cameras,
datasets,
dataloaders,
model,
stats,
device=device,
)
def _eval_and_dump(
cfg,
task: Task,
all_train_cameras: Optional[CamerasBase],
datasets: DatasetMap,
dataloaders: DataLoaderMap,
model,
@@ -570,13 +590,7 @@ def _eval_and_dump(
if dataloader is None:
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
if task == Task.SINGLE_SEQUENCE:
if datasets.train is None:
raise ValueError("train dataset must be provided")
all_source_cameras = _get_all_source_cameras(datasets.train)
else:
all_source_cameras = None
results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
# add the evaluation epoch to the results
for r in results:
@@ -603,7 +617,7 @@ def _get_eval_frame_data(frame_data):
return frame_data_for_eval
def _run_eval(model, all_source_cameras, loader, task: Task, device):
def _run_eval(model, all_train_cameras, loader, task: Task, device):
"""
Run the evaluation loop on the test dataloader
"""
@@ -631,7 +645,7 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
implicitron_render,
bg_color="black",
lpips_model=lpips_model,
source_cameras=all_source_cameras,
source_cameras=all_train_cameras,
)
)
@@ -642,31 +656,6 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
return category_result["results"]
def _get_all_source_cameras(
dataset: JsonIndexDataset,
num_workers: int = 8,
) -> CamerasBase:
"""
Load and return all the source cameras in the training dataset
"""
all_frame_data = next(
iter(
torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=len(dataset),
num_workers=num_workers,
collate_fn=FrameData.collate,
)
)
)
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
return source_cameras
def _seed_all_random_engines(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)