lazy all_train_cameras

Summary: Avoid calculating all_train_cameras before it is needed, because it is slow in some datasets.

Reviewed By: shapovalov

Differential Revision: D38037157

fbshipit-source-id: 95461226655cde2626b680661951ab17ebb0ec75
This commit is contained in:
Jeremy Reizenstein 2022-07-21 15:04:00 -07:00 committed by Facebook GitHub Bot
parent b2dc520210
commit 3783437d2f
6 changed files with 19 additions and 13 deletions

View File

@ -391,7 +391,6 @@ def run_training(cfg: DictConfig) -> None:
datasource = ImplicitronDataSource(**cfg.data_source_args) datasource = ImplicitronDataSource(**cfg.data_source_args)
datasets, dataloaders = datasource.get_datasets_and_dataloaders() datasets, dataloaders = datasource.get_datasets_and_dataloaders()
task = datasource.get_task() task = datasource.get_task()
all_train_cameras = datasource.get_all_train_cameras()
# init the model # init the model
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator) model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
@ -405,7 +404,7 @@ def run_training(cfg: DictConfig) -> None:
_eval_and_dump( _eval_and_dump(
cfg, cfg,
task, task,
all_train_cameras, datasource.all_train_cameras,
datasets, datasets,
dataloaders, dataloaders,
model, model,
@ -490,7 +489,7 @@ def run_training(cfg: DictConfig) -> None:
): ):
_run_eval( _run_eval(
model, model,
all_train_cameras, datasource.all_train_cameras,
dataloaders.test, dataloaders.test,
task, task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks, camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
@ -525,7 +524,7 @@ def run_training(cfg: DictConfig) -> None:
_eval_and_dump( _eval_and_dump(
cfg, cfg,
task, task,
all_train_cameras, datasource.all_train_cameras,
datasets, datasets,
dataloaders, dataloaders,
model, model,

View File

@ -30,9 +30,10 @@ class DataSourceBase(ReplaceableBase):
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
raise NotImplementedError() raise NotImplementedError()
def get_all_train_cameras(self) -> Optional[CamerasBase]: @property
def all_train_cameras(self) -> Optional[CamerasBase]:
""" """
If the data is all for a single scene, returns a list If the data is all for a single scene, a list
of the known training cameras for that scene, which is of the known training cameras for that scene, which is
used for evaluating the viewpoint difficulty of the used for evaluating the viewpoint difficulty of the
unseen cameras. unseen cameras.
@ -59,6 +60,7 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
datasets = self.dataset_map_provider.get_dataset_map() datasets = self.dataset_map_provider.get_dataset_map()
@ -68,5 +70,10 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
def get_task(self) -> Task: def get_task(self) -> Task:
return self.dataset_map_provider.get_task() return self.dataset_map_provider.get_task()
def get_all_train_cameras(self) -> Optional[CamerasBase]: @property
return self.dataset_map_provider.get_all_train_cameras() def all_train_cameras(self) -> Optional[CamerasBase]:
if self._all_train_cameras_cache is None: # pyre-ignore[16]
all_train_cameras = self.dataset_map_provider.get_all_train_cameras()
self._all_train_cameras_cache = (all_train_cameras,)
return self._all_train_cameras_cache[0]

View File

@ -118,8 +118,6 @@ def evaluate_dbir_for_category(
if test_dataset is None or test_dataloader is None: if test_dataset is None or test_dataloader is None:
raise ValueError("must have a test dataset.") raise ValueError("must have a test dataset.")
all_train_cameras = data_source.get_all_train_cameras()
image_size = cast(JsonIndexDataset, test_dataset).image_width image_size = cast(JsonIndexDataset, test_dataset).image_width
if image_size is None: if image_size is None:
@ -149,7 +147,7 @@ def evaluate_dbir_for_category(
preds["implicitron_render"], preds["implicitron_render"],
bg_color=bg_color, bg_color=bg_color,
lpips_model=lpips_model, lpips_model=lpips_model,
source_cameras=all_train_cameras, source_cameras=data_source.all_train_cameras,
) )
) )

View File

@ -10,6 +10,7 @@ from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
import numpy as np
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DoublePoolBatchSampler, DoublePoolBatchSampler,
) )
@ -53,6 +54,7 @@ class MockDataset(DatasetBase):
class TestSceneBatchSampler(unittest.TestCase): class TestSceneBatchSampler(unittest.TestCase):
def setUp(self): def setUp(self):
np.random.seed(42)
self.dataset_overfit = MockDataset(1) self.dataset_overfit = MockDataset(1)
def test_overfit(self): def test_overfit(self):

View File

@ -31,7 +31,7 @@ class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
data_source = ImplicitronDataSource(**args) data_source = ImplicitronDataSource(**args)
cameras = data_source.get_all_train_cameras() cameras = data_source.all_train_cameras
self.assertIsInstance(cameras, PerspectiveCameras) self.assertIsInstance(cameras, PerspectiveCameras)
self.assertEqual(len(cameras), 81) self.assertEqual(len(cameras), 81)

View File

@ -152,6 +152,6 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
self.assertEqual(i.frame_type, ["unseen"]) self.assertEqual(i.frame_type, ["unseen"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800)) self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
cameras = data_source.get_all_train_cameras() cameras = data_source.all_train_cameras
self.assertIsInstance(cameras, PerspectiveCameras) self.assertIsInstance(cameras, PerspectiveCameras)
self.assertEqual(len(cameras), 100) self.assertEqual(len(cameras), 100)