get_all_train_cameras

Summary: As part of removing Task, make the dataset code generate the source cameras for itself. There is a small optimization still available here: JsonIndexDataset could avoid loading images when only the cameras are needed.

Reviewed By: shapovalov

Differential Revision: D37313423

fbshipit-source-id: 3e5e0b2aabbf9cc51f10547a3523e98c72ad8755
Jeremy Reizenstein, 2022-07-06 07:13:41 -07:00 (committed by Facebook GitHub Bot)
parent 771cf8a328
commit 4e87c2b7f1
12 changed files with 139 additions and 94 deletions
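
To make the new interface concrete, here is a minimal usage sketch mirroring the test added in this commit; the dataset_root value is a placeholder, everything else follows the diff below:

# Minimal usage sketch of the new interface, mirroring the test added below.
# The dataset_root is a placeholder; substitute a local CO3D extraction.
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args

args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args.category = "skateboard"
dataset_args.dataset_root = "/path/to/co3d"  # placeholder path

data_source = ImplicitronDataSource(**args)
datasets, dataloaders = data_source.get_datasets_and_dataloaders()

# New in this commit: the data source supplies the known training cameras
# itself (a camera batch for single-scene tasks, None for multi-sequence),
# replacing the per-caller _get_all_source_cameras helpers deleted below.
all_train_cameras = data_source.get_all_train_cameras()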

projects/implicitron_trainer/experiment.py

@@ -66,9 +66,7 @@ from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
 from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
-from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 from pytorch3d.implicitron.tools import model_io, vis_utils
@@ -456,6 +454,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
     datasource = ImplicitronDataSource(**cfg.data_source_args)
     datasets, dataloaders = datasource.get_datasets_and_dataloaders()
     task = datasource.get_task()
+    all_train_cameras = datasource.get_all_train_cameras()

     # init the model
     model, stats, optimizer_state = init_model(cfg)
@@ -466,7 +465,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
     # only run evaluation on the test dataloader
     if cfg.eval_only:
-        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(
+            cfg,
+            task,
+            all_train_cameras,
+            datasets,
+            dataloaders,
+            model,
+            stats,
+            device=device,
+        )
         return

     # init the optimizer
@@ -528,7 +536,9 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
             and cfg.test_interval > 0
             and epoch % cfg.test_interval == 0
         ):
-            _run_eval(model, stats, dataloaders.test, task, device=device)
+            _run_eval(
+                model, all_train_cameras, dataloaders.test, task, device=device
+            )

         assert stats.epoch == epoch, "inconsistent stats!"
@@ -548,12 +558,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
             logger.info(f"LR change! {cur_lr} -> {new_lr}")

     if cfg.test_when_finished:
-        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(
+            cfg,
+            task,
+            all_train_cameras,
+            datasets,
+            dataloaders,
+            model,
+            stats,
+            device=device,
+        )


 def _eval_and_dump(
     cfg,
     task: Task,
+    all_train_cameras: Optional[CamerasBase],
     datasets: DatasetMap,
     dataloaders: DataLoaderMap,
     model,
@@ -570,13 +590,7 @@ def _eval_and_dump(
     if dataloader is None:
         raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')

-    if task == Task.SINGLE_SEQUENCE:
-        if datasets.train is None:
-            raise ValueError("train dataset must be provided")
-        all_source_cameras = _get_all_source_cameras(datasets.train)
-    else:
-        all_source_cameras = None
-
-    results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
+    results = _run_eval(model, all_train_cameras, dataloader, task, device=device)

     # add the evaluation epoch to the results
     for r in results:
@@ -603,7 +617,7 @@ def _get_eval_frame_data(frame_data):
     return frame_data_for_eval


-def _run_eval(model, all_source_cameras, loader, task: Task, device):
+def _run_eval(model, all_train_cameras, loader, task: Task, device):
     """
     Run the evaluation loop on the test dataloader
     """
@@ -631,7 +645,7 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
                 implicitron_render,
                 bg_color="black",
                 lpips_model=lpips_model,
-                source_cameras=all_source_cameras,
+                source_cameras=all_train_cameras,
             )
         )
@@ -642,31 +656,6 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
     return category_result["results"]


-def _get_all_source_cameras(
-    dataset: JsonIndexDataset,
-    num_workers: int = 8,
-) -> CamerasBase:
-    """
-    Load and return all the source cameras in the training dataset
-    """
-    all_frame_data = next(
-        iter(
-            torch.utils.data.DataLoader(
-                dataset,
-                shuffle=False,
-                batch_size=len(dataset),
-                num_workers=num_workers,
-                collate_fn=FrameData.collate,
-            )
-        )
-    )
-
-    is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
-    source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
-    return source_cameras
-
-
 def _seed_all_random_engines(seed: int):
     np.random.seed(seed)
     torch.manual_seed(seed)

pytorch3d/implicitron/dataset/data_source.py

@@ -4,13 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Tuple
+from typing import Optional, Tuple

 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
     run_auto_creation,
 )
+from pytorch3d.renderer.cameras import CamerasBase

 from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
@@ -28,6 +29,15 @@ class DataSourceBase(ReplaceableBase):
     def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         raise NotImplementedError()

+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        """
+        If the data is all for a single scene, returns a list
+        of the known training cameras for that scene, which is
+        used for evaluating the viewpoint difficulty of the
+        unseen cameras.
+        """
+        raise NotImplementedError()
+

 @registry.register
 class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
@@ -56,3 +66,6 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
     def get_task(self) -> Task:
         return self.dataset_map_provider.get_task()
+
+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        return self.dataset_map_provider.get_all_train_cameras()

pytorch3d/implicitron/dataset/dataset_base.py

@@ -224,7 +224,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
     # _seq_to_idx: Dict[str, List[int]] = field(init=False)

     def __len__(self) -> int:
-        raise NotImplementedError
+        raise NotImplementedError()

     def get_frame_numbers_and_timestamps(
         self, idxs: Sequence[int]

pytorch3d/implicitron/dataset/dataset_map_provider.py

@@ -12,6 +12,7 @@ from typing import Iterator, Optional

 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from pytorch3d.renderer.cameras import CamerasBase

 from .dataset_base import DatasetBase
@@ -73,6 +74,15 @@ class DatasetMapProviderBase(ReplaceableBase):
     def get_task(self) -> Task:
         raise NotImplementedError()

+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        """
+        If the data is all for a single scene, returns a list
+        of the known training cameras for that scene, which is
+        used for evaluating the difficulty of the unknown
+        cameras. Otherwise return None.
+        """
+        raise NotImplementedError()
+

 @registry.register
 class PathManagerFactory(ReplaceableBase):

pytorch3d/implicitron/dataset/json_index_dataset.py

@@ -32,11 +32,13 @@ import torch
 from PIL import Image
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.io import IO
-from pytorch3d.renderer.cameras import PerspectiveCameras
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
+from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 from pytorch3d.structures.pointclouds import Pointclouds

 from . import types
 from .dataset_base import DatasetBase, FrameData
+from .utils import is_known_frame_scalar

 logger = logging.getLogger(__name__)
@@ -206,6 +208,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
     def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
         return entry["subset"]

+    def get_all_train_cameras(self) -> CamerasBase:
+        """
+        Returns the cameras corresponding to all the known frames.
+        """
+        cameras = []
+        # pyre-ignore[16]
+        for frame_idx, frame_annot in enumerate(self.frame_annots):
+            frame_type = self._get_frame_type(frame_annot)
+            if frame_type is None:
+                raise ValueError("subsets not loaded")
+            if is_known_frame_scalar(frame_type):
+                cameras.append(self[frame_idx].camera)
+        return join_cameras_as_batch(cameras)
+
     def __getitem__(self, index) -> FrameData:
         # pyre-ignore[16]
         if index >= len(self.frame_annots):
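
The "small optimization" mentioned in the summary refers to the loop above: `self[frame_idx]` runs the full FrameData pipeline, so fetching a camera also loads the frame's image. A hypothetical variant, not part of this commit, could read cameras straight from the annotations; it assumes an internal helper along the lines of `self._get_pytorch3d_camera(entry, scale, clamp_bbox_xyxy)` and ignores any crop-dependent camera adjustment:

    # Hypothetical sketch (not in this commit): build cameras directly from
    # the frame annotations, skipping the image loading done by self[frame_idx].
    # Assumes a helper like self._get_pytorch3d_camera(entry, scale, bbox)
    # exists and that no bounding-box cropping affects the cameras.
    def get_all_train_cameras_without_images(self) -> CamerasBase:
        cameras = []
        for frame_annot in self.frame_annots:
            frame_type = self._get_frame_type(frame_annot)
            if frame_type is None:
                raise ValueError("subsets not loaded")
            if is_known_frame_scalar(frame_type):
                entry = frame_annot["frame_annotation"]
                cameras.append(self._get_pytorch3d_camera(entry, 1.0, None))
        return join_cameras_as_batch(cameras)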

pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py

@@ -7,7 +7,7 @@

 import json
 import os
-from typing import Dict, List, Tuple, Type
+from typing import Dict, List, Optional, Tuple, Type

 from omegaconf import DictConfig, open_dict
 from pytorch3d.implicitron.tools.config import (
@@ -15,6 +15,7 @@ from pytorch3d.implicitron.tools.config import (
     registry,
     run_auto_creation,
 )
+from pytorch3d.renderer.cameras import CamerasBase

 from .dataset_map_provider import (
     DatasetMap,
@@ -267,6 +268,15 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
     def get_task(self) -> Task:
         return Task(self.task_str)

+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        if Task(self.task_str) == Task.MULTI_SEQUENCE:
+            return None
+        # pyre-ignore[16]
+        train_dataset = self.dataset_map.train
+        assert isinstance(train_dataset, JsonIndexDataset)
+        return train_dataset.get_all_train_cameras()
+

 def _get_co3d_set_names_mapping(
     task: Task,

pytorch3d/implicitron/dataset/single_sequence_dataset.py

@@ -18,7 +18,7 @@ from pytorch3d.implicitron.tools.config import (
     expand_args_fields,
     run_auto_creation,
 )
-from pytorch3d.renderer import PerspectiveCameras
+from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras

 from .dataset_base import DatasetBase, FrameData
 from .dataset_map_provider import (
@@ -110,7 +110,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
         # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
         # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
         # - splits: List[List[int]] of indices for train/val/test subsets.
-        raise NotImplementedError
+        raise NotImplementedError()

     def _get_dataset(
         self, split_idx: int, frame_type: str, set_eval_batches: bool = False
@@ -162,6 +162,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
     def get_task(self) -> Task:
         return Task.SINGLE_SEQUENCE

+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        # pyre-ignore[16]
+        cameras = [self.poses[i] for i in self.i_split[0]]
+        return join_cameras_as_batch(cameras)
+

 def _interpret_blender_cameras(
     poses: torch.Tensor, H: int, W: int, focal: float
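
Both implementations rely on `join_cameras_as_batch`, which concatenates a list of same-typed cameras into a single batched camera object whose length is the number of views. A standalone illustration, not taken from the diff:

# Standalone illustration of join_cameras_as_batch (not part of this commit):
# batching N single cameras yields one PerspectiveCameras of length N.
import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

cams = [
    PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3))
    for _ in range(5)
]
batch = join_cameras_as_batch(cams)
assert isinstance(batch, PerspectiveCameras) and len(batch) == 5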

pytorch3d/implicitron/eval_demo.py

@@ -12,12 +12,10 @@ from typing import Any, cast, Dict, List, Optional, Tuple

 import lpips
 import torch
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
     CO3D_CATEGORIES,
 )
-from pytorch3d.implicitron.dataset.utils import is_known_frame
 from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
     aggregate_nvs_results,
     eval_batch,
@@ -120,16 +118,7 @@ def evaluate_dbir_for_category(
     if test_dataset is None or test_dataloader is None:
         raise ValueError("must have a test dataset.")

-    if task == Task.SINGLE_SEQUENCE:
-        # all_source_cameras are needed for evaluation of the
-        # target camera difficulty
-        # pyre-fixme[16]: `JsonIndexDataset` has no attribute `frame_annots`.
-        sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
-        all_source_cameras = _get_all_source_cameras(
-            test_dataset, sequence_name, num_workers=num_workers
-        )
-    else:
-        all_source_cameras = None
+    all_train_cameras = data_source.get_all_train_cameras()

     image_size = cast(JsonIndexDataset, test_dataset).image_width
@@ -160,7 +149,7 @@ def evaluate_dbir_for_category(
                 preds["implicitron_render"],
                 bg_color=bg_color,
                 lpips_model=lpips_model,
-                source_cameras=all_source_cameras,
+                source_cameras=all_train_cameras,
             )
         )
@@ -184,35 +173,5 @@ def _print_aggregate_results(
     print("")


-def _get_all_source_cameras(
-    dataset: DatasetBase, sequence_name: str, num_workers: int = 8
-):
-    """
-    Loads all training cameras of a given sequence.
-    The set of all seen cameras is needed for evaluating the viewpoint difficulty
-    for the singlescene evaluation.
-
-    Args:
-        dataset: Co3D dataset object.
-        sequence_name: The name of the sequence.
-        num_workers: The number of for the utilized dataloader.
-    """
-    # load all source cameras of the sequence
-    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
-    dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
-    (all_frame_data,) = torch.utils.data.DataLoader(
-        dataset_for_loader,
-        shuffle=False,
-        batch_size=len(dataset_for_loader),
-        num_workers=num_workers,
-        collate_fn=dataset.frame_data_type.collate,
-    )
-    is_known = is_known_frame(all_frame_data.frame_type)
-    source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
-    return source_cameras
-
-
 if __name__ == "__main__":
     main()

pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py

@@ -143,7 +143,7 @@ def eval_batch(
     visualize: bool = False,
     visualize_visdom_env: str = "eval_debug",
     break_after_visualising: bool = True,
-    source_cameras: Optional[List[CamerasBase]] = None,
+    source_cameras: Optional[CamerasBase] = None,
 ) -> Dict[str, Any]:
     """
     Produce performance metrics for a single batch of new-view synthesis

tests/implicitron/test_data_json_index.py (new file)

@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
+from pytorch3d.implicitron.tools.config import get_default_args
+from pytorch3d.renderer import PerspectiveCameras
+from tests.common_testing import TestCaseMixin
+
+# These tests are only run internally, where the data is available.
+internal = os.environ.get("FB_TEST", False)
+inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
+skip_tests = not internal or inside_re_worker
+
+
+@unittest.skipIf(skip_tests, "no data")
+class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
+    def test_loaders(self):
+        args = get_default_args(ImplicitronDataSource)
+        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
+        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        dataset_args.category = "skateboard"
+        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
+        dataset_args.test_restrict_sequence_id = 0
+        dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 1
+
+        data_source = ImplicitronDataSource(**args)
+
+        cameras = data_source.get_all_train_cameras()
+        self.assertIsInstance(cameras, PerspectiveCameras)
+        self.assertEqual(len(cameras), 81)
+
+        data_sets, data_loaders = data_source.get_datasets_and_dataloaders()
+        self.assertEqual(len(data_sets.train), 81)
+        self.assertEqual(len(data_sets.val), 102)
+        self.assertEqual(len(data_sets.test), 102)

tests/implicitron/test_data_llff.py

@@ -16,6 +16,7 @@ from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
     LlffDatasetMapProvider,
 )
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
+from pytorch3d.renderer import PerspectiveCameras

 from tests.common_testing import TestCaseMixin
@@ -123,3 +124,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
         for i in data_loaders.test:
             self.assertEqual(i.frame_type, ["unseen"])
             self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
+
+        cameras = data_source.get_all_train_cameras()
+        self.assertIsInstance(cameras, PerspectiveCameras)
+        self.assertEqual(len(cameras), 100)

tests/implicitron/test_evaluation.py

@@ -24,11 +24,7 @@ from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_

-if os.environ.get("FB_TEST", False):
-    from .common_resources import get_skateboard_data, provide_lpips_vgg
-else:
-    from common_resources import get_skateboard_data, provide_lpips_vgg
+from .common_resources import get_skateboard_data, provide_lpips_vgg


 class TestEvaluation(unittest.TestCase):