mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
get_all_train_cameras
Summary: As part of removing Task, make the dataset code generate the source cameras for itself. There's a small optimization available here, in that the JsonIndexDataset could avoid loading images. Reviewed By: shapovalov Differential Revision: D37313423 fbshipit-source-id: 3e5e0b2aabbf9cc51f10547a3523e98c72ad8755
This commit is contained in:
committed by
Facebook GitHub Bot
parent
771cf8a328
commit
4e87c2b7f1
@@ -4,13 +4,14 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||
@@ -28,6 +29,15 @@ class DataSourceBase(ReplaceableBase):
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
    """
    Return a (DatasetMap, DataLoaderMap) pair.

    The base class provides no implementation; subclasses
    are expected to override this method.
    """
    raise NotImplementedError()
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
    """
    Return the known training cameras when the data all
    belongs to a single scene. They are used for evaluating
    the viewpoint difficulty of the unseen cameras.

    The return type is optional, so implementations may
    return None when no such set of cameras applies.

    The base class provides no implementation.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
@@ -56,3 +66,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
|
||||
def get_task(self) -> Task:
    """
    Return the task reported by the dataset map provider.
    """
    provider = self.dataset_map_provider
    return provider.get_task()
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
    """
    Return whatever the dataset map provider reports as its
    training cameras; this data source only forwards the call.
    """
    provider = self.dataset_map_provider
    return provider.get_all_train_cameras()
|
||||
|
||||
@@ -224,7 +224,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_frame_numbers_and_timestamps(
|
||||
self, idxs: Sequence[int]
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Iterator, Optional
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .dataset_base import DatasetBase
|
||||
|
||||
@@ -73,6 +74,15 @@ class DatasetMapProviderBase(ReplaceableBase):
|
||||
def get_task(self) -> Task:
    """
    Return the Task associated with this provider.

    The base class leaves this unimplemented.
    """
    raise NotImplementedError()
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
    """
    Return the known training cameras of the scene when all
    of the data is for a single scene, and None otherwise.

    The cameras are used for evaluating the difficulty of
    the unknown cameras.

    The base class leaves this unimplemented.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class PathManagerFactory(ReplaceableBase):
|
||||
|
||||
@@ -32,11 +32,13 @@ import torch
|
||||
from PIL import Image
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
from . import types
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .utils import is_known_frame_scalar
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -206,6 +208,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
    """
    Return the frame type of an annotation entry, i.e. the
    value stored under its "subset" key.
    """
    subset = entry["subset"]
    return subset
|
||||
|
||||
def get_all_train_cameras(self) -> CamerasBase:
    """
    Gather the camera of every frame whose type marks it as
    known and return them joined into a single batch.

    Raises:
        ValueError: if a frame annotation has no frame type,
            i.e. the subsets were not loaded.
    """
    known_cameras = []
    # pyre-ignore[16]
    for idx, annot in enumerate(self.frame_annots):
        subset = self._get_frame_type(annot)
        if subset is None:
            raise ValueError("subsets not loaded")
        if not is_known_frame_scalar(subset):
            continue
        # NOTE(review): indexing the dataset goes through __getitem__ and
        # presumably loads the whole frame just to read its camera - confirm
        # whether a cheaper camera-only path is worthwhile.
        known_cameras.append(self[idx].camera)
    return join_cameras_as_batch(known_cameras)
|
||||
|
||||
def __getitem__(self, index) -> FrameData:
|
||||
# pyre-ignore[16]
|
||||
if index >= len(self.frame_annots):
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
@@ -15,6 +15,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .dataset_map_provider import (
|
||||
DatasetMap,
|
||||
@@ -267,6 +268,15 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
def get_task(self) -> Task:
    """
    Build the Task from this provider's task_str.
    """
    task = Task(self.task_str)
    return task
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
    """
    Return the training cameras of the train dataset, or
    None when the configured task is Task.MULTI_SEQUENCE.
    """
    is_multi_sequence = Task(self.task_str) == Task.MULTI_SEQUENCE
    if not is_multi_sequence:
        # pyre-ignore[16]
        dataset = self.dataset_map.train
        # narrows the type so that the dataset-specific method is available
        assert isinstance(dataset, JsonIndexDataset)
        return dataset.get_all_train_cameras()
    return None
|
||||
|
||||
|
||||
def _get_co3d_set_names_mapping(
|
||||
task: Task,
|
||||
|
||||
@@ -18,7 +18,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer import PerspectiveCameras
|
||||
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
||||
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .dataset_map_provider import (
|
||||
@@ -110,7 +110,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
# - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
|
||||
# - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
|
||||
# - splits: List[List[int]] of indices for train/val/test subsets.
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_dataset(
|
||||
self, split_idx: int, frame_type: str, set_eval_batches: bool = False
|
||||
@@ -162,6 +162,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
def get_task(self) -> Task:
    """
    Single-scene providers always report Task.SINGLE_SEQUENCE.
    """
    task = Task.SINGLE_SEQUENCE
    return task
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
    """
    Return the poses selected by the first split, joined
    into a single batch of cameras.
    """
    # presumably i_split[0] holds the train indices - verify against the loader
    # pyre-ignore[16]
    train_poses = [self.poses[idx] for idx in self.i_split[0]]
    return join_cameras_as_batch(train_poses)
|
||||
|
||||
|
||||
def _interpret_blender_cameras(
|
||||
poses: torch.Tensor, H: int, W: int, focal: float
|
||||
|
||||
@@ -12,12 +12,10 @@ from typing import Any, cast, Dict, List, Optional, Tuple
|
||||
import lpips
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||
CO3D_CATEGORIES,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
aggregate_nvs_results,
|
||||
eval_batch,
|
||||
@@ -120,16 +118,7 @@ def evaluate_dbir_for_category(
|
||||
if test_dataset is None or test_dataloader is None:
|
||||
raise ValueError("must have a test dataset.")
|
||||
|
||||
if task == Task.SINGLE_SEQUENCE:
|
||||
# all_source_cameras are needed for evaluation of the
|
||||
# target camera difficulty
|
||||
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `frame_annots`.
|
||||
sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
|
||||
all_source_cameras = _get_all_source_cameras(
|
||||
test_dataset, sequence_name, num_workers=num_workers
|
||||
)
|
||||
else:
|
||||
all_source_cameras = None
|
||||
all_train_cameras = data_source.get_all_train_cameras()
|
||||
|
||||
image_size = cast(JsonIndexDataset, test_dataset).image_width
|
||||
|
||||
@@ -160,7 +149,7 @@ def evaluate_dbir_for_category(
|
||||
preds["implicitron_render"],
|
||||
bg_color=bg_color,
|
||||
lpips_model=lpips_model,
|
||||
source_cameras=all_source_cameras,
|
||||
source_cameras=all_train_cameras,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -184,35 +173,5 @@ def _print_aggregate_results(
|
||||
print("")
|
||||
|
||||
|
||||
def _get_all_source_cameras(
    dataset: DatasetBase, sequence_name: str, num_workers: int = 8
):
    """
    Load all training cameras of a given sequence.

    The set of all seen cameras is needed for evaluating the viewpoint
    difficulty for the singlescene evaluation.

    Args:
        dataset: Co3D dataset object.
        sequence_name: The name of the sequence.
        num_workers: The number of workers for the utilized dataloader.

    Returns:
        The cameras of the frames of the sequence which are known.
    """
    frame_indices = list(dataset.sequence_indices_in_order(sequence_name))
    sequence_subset = torch.utils.data.Subset(dataset, frame_indices)
    # the batch size equals the subset size, so the loader yields one batch
    loader = torch.utils.data.DataLoader(
        sequence_subset,
        shuffle=False,
        batch_size=len(sequence_subset),
        num_workers=num_workers,
        collate_fn=dataset.frame_data_type.collate,
    )
    (whole_sequence,) = loader
    known_mask = is_known_frame(whole_sequence.frame_type)
    known_positions = torch.where(known_mask)[0]
    return whole_sequence.camera[known_positions]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -143,7 +143,7 @@ def eval_batch(
|
||||
visualize: bool = False,
|
||||
visualize_visdom_env: str = "eval_debug",
|
||||
break_after_visualising: bool = True,
|
||||
source_cameras: Optional[List[CamerasBase]] = None,
|
||||
source_cameras: Optional[CamerasBase] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Produce performance metrics for a single batch of new-view synthesis
|
||||
|
||||
Reference in New Issue
Block a user