get_all_train_cameras

Summary: As part of removing Task, make the dataset code generate the source cameras for itself. A small optimization remains available here: JsonIndexDataset could avoid loading images, since it currently assembles full frames only to read off their cameras.
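
A minimal sketch of the resulting call pattern (names as in the diff below; the deleted helper is shown only for contrast):

    # Before: evaluation code rebuilt the cameras from the train dataset.
    #     all_source_cameras = _get_all_source_cameras(datasets.train)
    # After: the data source answers directly; None for multi-sequence data.
    data_source = ImplicitronDataSource(**cfg.data_source_args)
    datasets, dataloaders = data_source.get_datasets_and_dataloaders()
    all_train_cameras = data_source.get_all_train_cameras()  # Optional[CamerasBase]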

Reviewed By: shapovalov

Differential Revision: D37313423

fbshipit-source-id: 3e5e0b2aabbf9cc51f10547a3523e98c72ad8755
Jeremy Reizenstein authored on 2022-07-06 07:13:41 -07:00; committed by Facebook GitHub Bot
parent 771cf8a328
commit 4e87c2b7f1
12 changed files with 139 additions and 94 deletions

View File

@@ -66,9 +66,7 @@ from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
 from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
-from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 from pytorch3d.implicitron.tools import model_io, vis_utils
@@ -456,6 +454,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
     datasource = ImplicitronDataSource(**cfg.data_source_args)
     datasets, dataloaders = datasource.get_datasets_and_dataloaders()
     task = datasource.get_task()
+    all_train_cameras = datasource.get_all_train_cameras()
 
     # init the model
     model, stats, optimizer_state = init_model(cfg)
@@ -466,7 +465,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
 
     # only run evaluation on the test dataloader
     if cfg.eval_only:
-        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(
+            cfg,
+            task,
+            all_train_cameras,
+            datasets,
+            dataloaders,
+            model,
+            stats,
+            device=device,
+        )
         return
 
     # init the optimizer
@@ -528,7 +536,9 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
 
             and cfg.test_interval > 0
             and epoch % cfg.test_interval == 0
         ):
-            _run_eval(model, stats, dataloaders.test, task, device=device)
+            _run_eval(
+                model, all_train_cameras, dataloaders.test, task, device=device
+            )
 
         assert stats.epoch == epoch, "inconsistent stats!"
@@ -548,12 +558,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
         logger.info(f"LR change! {cur_lr} -> {new_lr}")
 
     if cfg.test_when_finished:
-        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
+        _eval_and_dump(
+            cfg,
+            task,
+            all_train_cameras,
+            datasets,
+            dataloaders,
+            model,
+            stats,
+            device=device,
+        )
 
 
 def _eval_and_dump(
     cfg,
     task: Task,
+    all_train_cameras: Optional[CamerasBase],
     datasets: DatasetMap,
     dataloaders: DataLoaderMap,
     model,
@@ -570,13 +590,7 @@ def _eval_and_dump(
     if dataloader is None:
         raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
 
-    if task == Task.SINGLE_SEQUENCE:
-        if datasets.train is None:
-            raise ValueError("train dataset must be provided")
-        all_source_cameras = _get_all_source_cameras(datasets.train)
-    else:
-        all_source_cameras = None
-    results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
+    results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
 
     # add the evaluation epoch to the results
     for r in results:
@@ -603,7 +617,7 @@ def _get_eval_frame_data(frame_data):
     return frame_data_for_eval
 
 
-def _run_eval(model, all_source_cameras, loader, task: Task, device):
+def _run_eval(model, all_train_cameras, loader, task: Task, device):
     """
     Run the evaluation loop on the test dataloader
     """
@@ -631,7 +645,7 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
                 implicitron_render,
                 bg_color="black",
                 lpips_model=lpips_model,
-                source_cameras=all_source_cameras,
+                source_cameras=all_train_cameras,
             )
         )
 
@@ -642,31 +656,6 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
 
     return category_result["results"]
 
 
-def _get_all_source_cameras(
-    dataset: JsonIndexDataset,
-    num_workers: int = 8,
-) -> CamerasBase:
-    """
-    Load and return all the source cameras in the training dataset
-    """
-    all_frame_data = next(
-        iter(
-            torch.utils.data.DataLoader(
-                dataset,
-                shuffle=False,
-                batch_size=len(dataset),
-                num_workers=num_workers,
-                collate_fn=FrameData.collate,
-            )
-        )
-    )
-    is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
-    source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
-    return source_cameras
-
-
 def _seed_all_random_engines(seed: int):
     np.random.seed(seed)
     torch.manual_seed(seed)

View File

@@ -4,13 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
     run_auto_creation,
 )
+from pytorch3d.renderer.cameras import CamerasBase
 
 from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
@@ -28,6 +29,15 @@ class DataSourceBase(ReplaceableBase):
     def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         raise NotImplementedError()
 
+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        """
+        If the data is all for a single scene, returns a list
+        of the known training cameras for that scene, which is
+        used for evaluating the viewpoint difficulty of the
+        unseen cameras.
+        """
+        raise NotImplementedError()
+
 
 @registry.register
 class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
@@ -56,3 +66,6 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
     def get_task(self) -> Task:
         return self.dataset_map_provider.get_task()
+
+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        return self.dataset_map_provider.get_all_train_cameras()
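
A note for implementers of custom data sources: get_all_train_cameras is now part of the DataSourceBase contract (and of DatasetMapProviderBase below), so subclasses outside this diff must provide it. A minimal conforming sketch, where MyMultiSceneProvider is a hypothetical provider whose data spans several scenes:

    class MyMultiSceneProvider(DatasetMapProviderBase):
        def get_all_train_cameras(self) -> Optional[CamerasBase]:
            # Many scenes, so no single set of training cameras
            # exists to score viewpoint difficulty against.
            return None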

View File

@@ -224,7 +224,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
     # _seq_to_idx: Dict[str, List[int]] = field(init=False)
 
     def __len__(self) -> int:
-        raise NotImplementedError
+        raise NotImplementedError()
 
     def get_frame_numbers_and_timestamps(
         self, idxs: Sequence[int]
View File

@@ -12,6 +12,7 @@ from typing import Iterator, Optional
 
 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from pytorch3d.renderer.cameras import CamerasBase
 
 from .dataset_base import DatasetBase
@@ -73,6 +74,15 @@ class DatasetMapProviderBase(ReplaceableBase):
     def get_task(self) -> Task:
         raise NotImplementedError()
 
+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        """
+        If the data is all for a single scene, returns a list
+        of the known training cameras for that scene, which is
+        used for evaluating the difficulty of the unknown
+        cameras. Otherwise return None.
+        """
+        raise NotImplementedError()
+
 
 @registry.register
 class PathManagerFactory(ReplaceableBase):

View File

@@ -32,11 +32,13 @@ import torch
 from PIL import Image
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.io import IO
-from pytorch3d.renderer.cameras import PerspectiveCameras
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
+from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 from pytorch3d.structures.pointclouds import Pointclouds
 
 from . import types
 from .dataset_base import DatasetBase, FrameData
+from .utils import is_known_frame_scalar
 
 logger = logging.getLogger(__name__)
@@ -206,6 +208,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
     def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
         return entry["subset"]
 
+    def get_all_train_cameras(self) -> CamerasBase:
+        """
+        Returns the cameras corresponding to all the known frames.
+        """
+        cameras = []
+        # pyre-ignore[16]
+        for frame_idx, frame_annot in enumerate(self.frame_annots):
+            frame_type = self._get_frame_type(frame_annot)
+            if frame_type is None:
+                raise ValueError("subsets not loaded")
+            if is_known_frame_scalar(frame_type):
+                cameras.append(self[frame_idx].camera)
+        return join_cameras_as_batch(cameras)
+
     def __getitem__(self, index) -> FrameData:
         # pyre-ignore[16]
         if index >= len(self.frame_annots):
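
The "small optimization" mentioned in the summary refers to the loop above: self[frame_idx] builds a complete FrameData, images included, only for .camera to be read. A hypothetical sketch of the cheaper path, assuming the dataset's existing load_images / load_depths / load_masks flags can simply be toggled off around the loop (this is not part of the diff):

    def get_all_train_cameras(self) -> CamerasBase:
        # Temporarily suppress blob loading so self[i] builds cameras only.
        saved = self.load_images, self.load_depths, self.load_masks
        self.load_images = self.load_depths = self.load_masks = False
        try:
            cameras = []
            for i, annot in enumerate(self.frame_annots):
                frame_type = self._get_frame_type(annot)
                if frame_type is None:
                    raise ValueError("subsets not loaded")
                if is_known_frame_scalar(frame_type):
                    cameras.append(self[i].camera)
        finally:
            # Restore the configured loading behaviour.
            self.load_images, self.load_depths, self.load_masks = saved
        return join_cameras_as_batch(cameras)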

View File

@@ -7,7 +7,7 @@
 
 import json
 import os
-from typing import Dict, List, Tuple, Type
+from typing import Dict, List, Optional, Tuple, Type
 
 from omegaconf import DictConfig, open_dict
 from pytorch3d.implicitron.tools.config import (
@@ -15,6 +15,7 @@ from pytorch3d.implicitron.tools.config import (
     registry,
     run_auto_creation,
 )
+from pytorch3d.renderer.cameras import CamerasBase
 
 from .dataset_map_provider import (
     DatasetMap,
@@ -267,6 +268,15 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
     def get_task(self) -> Task:
         return Task(self.task_str)
 
+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        if Task(self.task_str) == Task.MULTI_SEQUENCE:
+            return None
+
+        # pyre-ignore[16]
+        train_dataset = self.dataset_map.train
+        assert isinstance(train_dataset, JsonIndexDataset)
+        return train_dataset.get_all_train_cameras()
+
 
 def _get_co3d_set_names_mapping(
     task: Task,

View File

@@ -18,7 +18,7 @@ from pytorch3d.implicitron.tools.config import (
     expand_args_fields,
     run_auto_creation,
 )
-from pytorch3d.renderer import PerspectiveCameras
+from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
 
 from .dataset_base import DatasetBase, FrameData
 from .dataset_map_provider import (
@@ -110,7 +110,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
        # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
        # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
        # - splits: List[List[int]] of indices for train/val/test subsets.
-        raise NotImplementedError
+        raise NotImplementedError()
 
     def _get_dataset(
         self, split_idx: int, frame_type: str, set_eval_batches: bool = False
@@ -162,6 +162,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
     def get_task(self) -> Task:
         return Task.SINGLE_SEQUENCE
 
+    def get_all_train_cameras(self) -> Optional[CamerasBase]:
+        # pyre-ignore[16]
+        cameras = [self.poses[i] for i in self.i_split[0]]
+        return join_cameras_as_batch(cameras)
+
 
 def _interpret_blender_cameras(
     poses: torch.Tensor, H: int, W: int, focal: float
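
Both new implementations funnel through join_cameras_as_batch, which concatenates a list of same-typed cameras into one batched CamerasBase supporting len() and indexing (which is why the tests below can assert len(cameras) == 81 and == 100). A small self-contained usage sketch with made-up identity cameras:

    import torch
    from pytorch3d.renderer import join_cameras_as_batch, PerspectiveCameras

    cams = [
        PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3))
        for _ in range(4)
    ]
    batch = join_cameras_as_batch(cams)  # one PerspectiveCameras, batch size 4
    assert len(batch) == 4
    assert isinstance(batch[0], PerspectiveCameras)  # indexing yields cameras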

View File

@@ -12,12 +12,10 @@ from typing import Any, cast, Dict, List, Optional, Tuple
 import lpips
 import torch
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
     CO3D_CATEGORIES,
 )
-from pytorch3d.implicitron.dataset.utils import is_known_frame
 from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
     aggregate_nvs_results,
     eval_batch,
@@ -120,16 +118,7 @@ def evaluate_dbir_for_category(
     if test_dataset is None or test_dataloader is None:
         raise ValueError("must have a test dataset.")
 
-    if task == Task.SINGLE_SEQUENCE:
-        # all_source_cameras are needed for evaluation of the
-        # target camera difficulty
-        # pyre-fixme[16]: `JsonIndexDataset` has no attribute `frame_annots`.
-        sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
-        all_source_cameras = _get_all_source_cameras(
-            test_dataset, sequence_name, num_workers=num_workers
-        )
-    else:
-        all_source_cameras = None
+    all_train_cameras = data_source.get_all_train_cameras()
 
     image_size = cast(JsonIndexDataset, test_dataset).image_width
 
@@ -160,7 +149,7 @@ def evaluate_dbir_for_category(
                 preds["implicitron_render"],
                 bg_color=bg_color,
                 lpips_model=lpips_model,
-                source_cameras=all_source_cameras,
+                source_cameras=all_train_cameras,
             )
         )
 
@@ -184,35 +173,5 @@ def _print_aggregate_results(
     print("")
 
 
-def _get_all_source_cameras(
-    dataset: DatasetBase, sequence_name: str, num_workers: int = 8
-):
-    """
-    Loads all training cameras of a given sequence.
-
-    The set of all seen cameras is needed for evaluating the viewpoint difficulty
-    for the singlescene evaluation.
-
-    Args:
-        dataset: Co3D dataset object.
-        sequence_name: The name of the sequence.
-        num_workers: The number of workers for the utilized dataloader.
-    """
-    # load all source cameras of the sequence
-    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
-    dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
-    (all_frame_data,) = torch.utils.data.DataLoader(
-        dataset_for_loader,
-        shuffle=False,
-        batch_size=len(dataset_for_loader),
-        num_workers=num_workers,
-        collate_fn=dataset.frame_data_type.collate,
-    )
-    is_known = is_known_frame(all_frame_data.frame_type)
-    source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
-    return source_cameras
 
 
 if __name__ == "__main__":
     main()

View File

@@ -143,7 +143,7 @@ def eval_batch(
     visualize: bool = False,
     visualize_visdom_env: str = "eval_debug",
     break_after_visualising: bool = True,
-    source_cameras: Optional[List[CamerasBase]] = None,
+    source_cameras: Optional[CamerasBase] = None,
 ) -> Dict[str, Any]:
     """
     Produce performance metrics for a single batch of new-view synthesis

View File

@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
+from pytorch3d.implicitron.tools.config import get_default_args
+from pytorch3d.renderer import PerspectiveCameras
+from tests.common_testing import TestCaseMixin
+
+# These tests are only run internally, where the data is available.
+internal = os.environ.get("FB_TEST", False)
+inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
+skip_tests = not internal or inside_re_worker
+
+
+@unittest.skipIf(skip_tests, "no data")
+class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
+    def test_loaders(self):
+        args = get_default_args(ImplicitronDataSource)
+        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
+        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        dataset_args.category = "skateboard"
+        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
+        dataset_args.test_restrict_sequence_id = 0
+        dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 1
+
+        data_source = ImplicitronDataSource(**args)
+
+        cameras = data_source.get_all_train_cameras()
+        self.assertIsInstance(cameras, PerspectiveCameras)
+        self.assertEqual(len(cameras), 81)
+
+        data_sets, data_loaders = data_source.get_datasets_and_dataloaders()
+        self.assertEqual(len(data_sets.train), 81)
+        self.assertEqual(len(data_sets.val), 102)
+        self.assertEqual(len(data_sets.test), 102)

View File

@@ -16,6 +16,7 @@ from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
     LlffDatasetMapProvider,
 )
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
+from pytorch3d.renderer import PerspectiveCameras
 from tests.common_testing import TestCaseMixin
 
@@ -123,3 +124,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
         for i in data_loaders.test:
             self.assertEqual(i.frame_type, ["unseen"])
             self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
+
+        cameras = data_source.get_all_train_cameras()
+        self.assertIsInstance(cameras, PerspectiveCameras)
+        self.assertEqual(len(cameras), 100)

View File

@@ -24,11 +24,7 @@ from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
 
-if os.environ.get("FB_TEST", False):
-    from .common_resources import get_skateboard_data, provide_lpips_vgg
-else:
-    from common_resources import get_skateboard_data, provide_lpips_vgg
+from .common_resources import get_skateboard_data, provide_lpips_vgg
 
 
 class TestEvaluation(unittest.TestCase):