mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
lazy all_train_cameras
Summary: Avoid calculating all_train_cameras before it is needed, because it is slow in some datasets. Reviewed By: shapovalov Differential Revision: D38037157 fbshipit-source-id: 95461226655cde2626b680661951ab17ebb0ec75
This commit is contained in:
parent
b2dc520210
commit
3783437d2f
@ -391,7 +391,6 @@ def run_training(cfg: DictConfig) -> None:
|
|||||||
datasource = ImplicitronDataSource(**cfg.data_source_args)
|
datasource = ImplicitronDataSource(**cfg.data_source_args)
|
||||||
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
|
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
|
||||||
task = datasource.get_task()
|
task = datasource.get_task()
|
||||||
all_train_cameras = datasource.get_all_train_cameras()
|
|
||||||
|
|
||||||
# init the model
|
# init the model
|
||||||
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
|
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
|
||||||
@ -405,7 +404,7 @@ def run_training(cfg: DictConfig) -> None:
|
|||||||
_eval_and_dump(
|
_eval_and_dump(
|
||||||
cfg,
|
cfg,
|
||||||
task,
|
task,
|
||||||
all_train_cameras,
|
datasource.all_train_cameras,
|
||||||
datasets,
|
datasets,
|
||||||
dataloaders,
|
dataloaders,
|
||||||
model,
|
model,
|
||||||
@ -490,7 +489,7 @@ def run_training(cfg: DictConfig) -> None:
|
|||||||
):
|
):
|
||||||
_run_eval(
|
_run_eval(
|
||||||
model,
|
model,
|
||||||
all_train_cameras,
|
datasource.all_train_cameras,
|
||||||
dataloaders.test,
|
dataloaders.test,
|
||||||
task,
|
task,
|
||||||
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||||
@ -525,7 +524,7 @@ def run_training(cfg: DictConfig) -> None:
|
|||||||
_eval_and_dump(
|
_eval_and_dump(
|
||||||
cfg,
|
cfg,
|
||||||
task,
|
task,
|
||||||
all_train_cameras,
|
datasource.all_train_cameras,
|
||||||
datasets,
|
datasets,
|
||||||
dataloaders,
|
dataloaders,
|
||||||
model,
|
model,
|
||||||
|
@ -30,9 +30,10 @@ class DataSourceBase(ReplaceableBase):
|
|||||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
@property
|
||||||
|
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
"""
|
"""
|
||||||
If the data is all for a single scene, returns a list
|
If the data is all for a single scene, a list
|
||||||
of the known training cameras for that scene, which is
|
of the known training cameras for that scene, which is
|
||||||
used for evaluating the viewpoint difficulty of the
|
used for evaluating the viewpoint difficulty of the
|
||||||
unseen cameras.
|
unseen cameras.
|
||||||
@ -59,6 +60,7 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
|
||||||
|
|
||||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
||||||
datasets = self.dataset_map_provider.get_dataset_map()
|
datasets = self.dataset_map_provider.get_dataset_map()
|
||||||
@ -68,5 +70,10 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
|||||||
def get_task(self) -> Task:
|
def get_task(self) -> Task:
|
||||||
return self.dataset_map_provider.get_task()
|
return self.dataset_map_provider.get_task()
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
@property
|
||||||
return self.dataset_map_provider.get_all_train_cameras()
|
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
if self._all_train_cameras_cache is None: # pyre-ignore[16]
|
||||||
|
all_train_cameras = self.dataset_map_provider.get_all_train_cameras()
|
||||||
|
self._all_train_cameras_cache = (all_train_cameras,)
|
||||||
|
|
||||||
|
return self._all_train_cameras_cache[0]
|
||||||
|
@ -118,8 +118,6 @@ def evaluate_dbir_for_category(
|
|||||||
if test_dataset is None or test_dataloader is None:
|
if test_dataset is None or test_dataloader is None:
|
||||||
raise ValueError("must have a test dataset.")
|
raise ValueError("must have a test dataset.")
|
||||||
|
|
||||||
all_train_cameras = data_source.get_all_train_cameras()
|
|
||||||
|
|
||||||
image_size = cast(JsonIndexDataset, test_dataset).image_width
|
image_size = cast(JsonIndexDataset, test_dataset).image_width
|
||||||
|
|
||||||
if image_size is None:
|
if image_size is None:
|
||||||
@ -149,7 +147,7 @@ def evaluate_dbir_for_category(
|
|||||||
preds["implicitron_render"],
|
preds["implicitron_render"],
|
||||||
bg_color=bg_color,
|
bg_color=bg_color,
|
||||||
lpips_model=lpips_model,
|
lpips_model=lpips_model,
|
||||||
source_cameras=all_train_cameras,
|
source_cameras=data_source.all_train_cameras,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||||
DoublePoolBatchSampler,
|
DoublePoolBatchSampler,
|
||||||
)
|
)
|
||||||
@ -53,6 +54,7 @@ class MockDataset(DatasetBase):
|
|||||||
|
|
||||||
class TestSceneBatchSampler(unittest.TestCase):
|
class TestSceneBatchSampler(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
np.random.seed(42)
|
||||||
self.dataset_overfit = MockDataset(1)
|
self.dataset_overfit = MockDataset(1)
|
||||||
|
|
||||||
def test_overfit(self):
|
def test_overfit(self):
|
||||||
|
@ -31,7 +31,7 @@ class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
data_source = ImplicitronDataSource(**args)
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
|
||||||
cameras = data_source.get_all_train_cameras()
|
cameras = data_source.all_train_cameras
|
||||||
self.assertIsInstance(cameras, PerspectiveCameras)
|
self.assertIsInstance(cameras, PerspectiveCameras)
|
||||||
self.assertEqual(len(cameras), 81)
|
self.assertEqual(len(cameras), 81)
|
||||||
|
|
||||||
|
@ -152,6 +152,6 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(i.frame_type, ["unseen"])
|
self.assertEqual(i.frame_type, ["unseen"])
|
||||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||||
|
|
||||||
cameras = data_source.get_all_train_cameras()
|
cameras = data_source.all_train_cameras
|
||||||
self.assertIsInstance(cameras, PerspectiveCameras)
|
self.assertIsInstance(cameras, PerspectiveCameras)
|
||||||
self.assertEqual(len(cameras), 100)
|
self.assertEqual(len(cameras), 100)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user