mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
get_all_train_cameras
Summary: As part of removing Task, make the dataset code generate the source cameras for itself. There's a small optimization available here, in that the JsonIndexDataset could avoid loading images. Reviewed By: shapovalov Differential Revision: D37313423 fbshipit-source-id: 3e5e0b2aabbf9cc51f10547a3523e98c72ad8755
This commit is contained in:
parent
771cf8a328
commit
4e87c2b7f1
@ -66,9 +66,7 @@ from packaging import version
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
@ -456,6 +454,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
datasource = ImplicitronDataSource(**cfg.data_source_args)
|
||||
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
|
||||
task = datasource.get_task()
|
||||
all_train_cameras = datasource.get_all_train_cameras()
|
||||
|
||||
# init the model
|
||||
model, stats, optimizer_state = init_model(cfg)
|
||||
@ -466,7 +465,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
|
||||
# only run evaluation on the test dataloader
|
||||
if cfg.eval_only:
|
||||
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
|
||||
_eval_and_dump(
|
||||
cfg,
|
||||
task,
|
||||
all_train_cameras,
|
||||
datasets,
|
||||
dataloaders,
|
||||
model,
|
||||
stats,
|
||||
device=device,
|
||||
)
|
||||
return
|
||||
|
||||
# init the optimizer
|
||||
@ -528,7 +536,9 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
and cfg.test_interval > 0
|
||||
and epoch % cfg.test_interval == 0
|
||||
):
|
||||
_run_eval(model, stats, dataloaders.test, task, device=device)
|
||||
_run_eval(
|
||||
model, all_train_cameras, dataloaders.test, task, device=device
|
||||
)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
|
||||
@ -548,12 +558,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
logger.info(f"LR change! {cur_lr} -> {new_lr}")
|
||||
|
||||
if cfg.test_when_finished:
|
||||
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
|
||||
_eval_and_dump(
|
||||
cfg,
|
||||
task,
|
||||
all_train_cameras,
|
||||
datasets,
|
||||
dataloaders,
|
||||
model,
|
||||
stats,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
def _eval_and_dump(
|
||||
cfg,
|
||||
task: Task,
|
||||
all_train_cameras: Optional[CamerasBase],
|
||||
datasets: DatasetMap,
|
||||
dataloaders: DataLoaderMap,
|
||||
model,
|
||||
@ -570,13 +590,7 @@ def _eval_and_dump(
|
||||
if dataloader is None:
|
||||
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
|
||||
|
||||
if task == Task.SINGLE_SEQUENCE:
|
||||
if datasets.train is None:
|
||||
raise ValueError("train dataset must be provided")
|
||||
all_source_cameras = _get_all_source_cameras(datasets.train)
|
||||
else:
|
||||
all_source_cameras = None
|
||||
results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
|
||||
results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
|
||||
|
||||
# add the evaluation epoch to the results
|
||||
for r in results:
|
||||
@ -603,7 +617,7 @@ def _get_eval_frame_data(frame_data):
|
||||
return frame_data_for_eval
|
||||
|
||||
|
||||
def _run_eval(model, all_source_cameras, loader, task: Task, device):
|
||||
def _run_eval(model, all_train_cameras, loader, task: Task, device):
|
||||
"""
|
||||
Run the evaluation loop on the test dataloader
|
||||
"""
|
||||
@ -631,7 +645,7 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
|
||||
implicitron_render,
|
||||
bg_color="black",
|
||||
lpips_model=lpips_model,
|
||||
source_cameras=all_source_cameras,
|
||||
source_cameras=all_train_cameras,
|
||||
)
|
||||
)
|
||||
|
||||
@ -642,31 +656,6 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
|
||||
return category_result["results"]
|
||||
|
||||
|
||||
def _get_all_source_cameras(
|
||||
dataset: JsonIndexDataset,
|
||||
num_workers: int = 8,
|
||||
) -> CamerasBase:
|
||||
"""
|
||||
Load and return all the source cameras in the training dataset
|
||||
"""
|
||||
|
||||
all_frame_data = next(
|
||||
iter(
|
||||
torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
batch_size=len(dataset),
|
||||
num_workers=num_workers,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
|
||||
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
|
||||
return source_cameras
|
||||
|
||||
|
||||
def _seed_all_random_engines(seed: int):
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
@ -4,13 +4,14 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||
@ -28,6 +29,15 @@ class DataSourceBase(ReplaceableBase):
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
"""
|
||||
If the data is all for a single scene, returns a list
|
||||
of the known training cameras for that scene, which is
|
||||
used for evaluating the viewpoint difficulty of the
|
||||
unseen cameras.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
@ -56,3 +66,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
|
||||
def get_task(self) -> Task:
|
||||
return self.dataset_map_provider.get_task()
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
return self.dataset_map_provider.get_all_train_cameras()
|
||||
|
@ -224,7 +224,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_frame_numbers_and_timestamps(
|
||||
self, idxs: Sequence[int]
|
||||
|
@ -12,6 +12,7 @@ from typing import Iterator, Optional
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .dataset_base import DatasetBase
|
||||
|
||||
@ -73,6 +74,15 @@ class DatasetMapProviderBase(ReplaceableBase):
|
||||
def get_task(self) -> Task:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
"""
|
||||
If the data is all for a single scene, returns a list
|
||||
of the known training cameras for that scene, which is
|
||||
used for evaluating the difficulty of the unknown
|
||||
cameras. Otherwise return None.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class PathManagerFactory(ReplaceableBase):
|
||||
|
@ -32,11 +32,13 @@ import torch
|
||||
from PIL import Image
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
from . import types
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .utils import is_known_frame_scalar
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -206,6 +208,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
||||
return entry["subset"]
|
||||
|
||||
def get_all_train_cameras(self) -> CamerasBase:
|
||||
"""
|
||||
Returns the cameras corresponding to all the known frames.
|
||||
"""
|
||||
cameras = []
|
||||
# pyre-ignore[16]
|
||||
for frame_idx, frame_annot in enumerate(self.frame_annots):
|
||||
frame_type = self._get_frame_type(frame_annot)
|
||||
if frame_type is None:
|
||||
raise ValueError("subsets not loaded")
|
||||
if is_known_frame_scalar(frame_type):
|
||||
cameras.append(self[frame_idx].camera)
|
||||
return join_cameras_as_batch(cameras)
|
||||
|
||||
def __getitem__(self, index) -> FrameData:
|
||||
# pyre-ignore[16]
|
||||
if index >= len(self.frame_annots):
|
||||
|
@ -7,7 +7,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
@ -15,6 +15,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .dataset_map_provider import (
|
||||
DatasetMap,
|
||||
@ -267,6 +268,15 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
def get_task(self) -> Task:
|
||||
return Task(self.task_str)
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
if Task(self.task_str) == Task.MULTI_SEQUENCE:
|
||||
return None
|
||||
|
||||
# pyre-ignore[16]
|
||||
train_dataset = self.dataset_map.train
|
||||
assert isinstance(train_dataset, JsonIndexDataset)
|
||||
return train_dataset.get_all_train_cameras()
|
||||
|
||||
|
||||
def _get_co3d_set_names_mapping(
|
||||
task: Task,
|
||||
|
@ -18,7 +18,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer import PerspectiveCameras
|
||||
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
||||
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .dataset_map_provider import (
|
||||
@ -110,7 +110,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
# - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
|
||||
# - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
|
||||
# - splits: List[List[int]] of indices for train/val/test subsets.
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_dataset(
|
||||
self, split_idx: int, frame_type: str, set_eval_batches: bool = False
|
||||
@ -162,6 +162,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
def get_task(self) -> Task:
|
||||
return Task.SINGLE_SEQUENCE
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
# pyre-ignore[16]
|
||||
cameras = [self.poses[i] for i in self.i_split[0]]
|
||||
return join_cameras_as_batch(cameras)
|
||||
|
||||
|
||||
def _interpret_blender_cameras(
|
||||
poses: torch.Tensor, H: int, W: int, focal: float
|
||||
|
@ -12,12 +12,10 @@ from typing import Any, cast, Dict, List, Optional, Tuple
|
||||
import lpips
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||
CO3D_CATEGORIES,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
aggregate_nvs_results,
|
||||
eval_batch,
|
||||
@ -120,16 +118,7 @@ def evaluate_dbir_for_category(
|
||||
if test_dataset is None or test_dataloader is None:
|
||||
raise ValueError("must have a test dataset.")
|
||||
|
||||
if task == Task.SINGLE_SEQUENCE:
|
||||
# all_source_cameras are needed for evaluation of the
|
||||
# target camera difficulty
|
||||
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `frame_annots`.
|
||||
sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
|
||||
all_source_cameras = _get_all_source_cameras(
|
||||
test_dataset, sequence_name, num_workers=num_workers
|
||||
)
|
||||
else:
|
||||
all_source_cameras = None
|
||||
all_train_cameras = data_source.get_all_train_cameras()
|
||||
|
||||
image_size = cast(JsonIndexDataset, test_dataset).image_width
|
||||
|
||||
@ -160,7 +149,7 @@ def evaluate_dbir_for_category(
|
||||
preds["implicitron_render"],
|
||||
bg_color=bg_color,
|
||||
lpips_model=lpips_model,
|
||||
source_cameras=all_source_cameras,
|
||||
source_cameras=all_train_cameras,
|
||||
)
|
||||
)
|
||||
|
||||
@ -184,35 +173,5 @@ def _print_aggregate_results(
|
||||
print("")
|
||||
|
||||
|
||||
def _get_all_source_cameras(
|
||||
dataset: DatasetBase, sequence_name: str, num_workers: int = 8
|
||||
):
|
||||
"""
|
||||
Loads all training cameras of a given sequence.
|
||||
|
||||
The set of all seen cameras is needed for evaluating the viewpoint difficulty
|
||||
for the singlescene evaluation.
|
||||
|
||||
Args:
|
||||
dataset: Co3D dataset object.
|
||||
sequence_name: The name of the sequence.
|
||||
num_workers: The number of for the utilized dataloader.
|
||||
"""
|
||||
|
||||
# load all source cameras of the sequence
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
|
||||
(all_frame_data,) = torch.utils.data.DataLoader(
|
||||
dataset_for_loader,
|
||||
shuffle=False,
|
||||
batch_size=len(dataset_for_loader),
|
||||
num_workers=num_workers,
|
||||
collate_fn=dataset.frame_data_type.collate,
|
||||
)
|
||||
is_known = is_known_frame(all_frame_data.frame_type)
|
||||
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
|
||||
return source_cameras
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -143,7 +143,7 @@ def eval_batch(
|
||||
visualize: bool = False,
|
||||
visualize_visdom_env: str = "eval_debug",
|
||||
break_after_visualising: bool = True,
|
||||
source_cameras: Optional[List[CamerasBase]] = None,
|
||||
source_cameras: Optional[CamerasBase] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Produce performance metrics for a single batch of new-view synthesis
|
||||
|
42
tests/implicitron/test_data_json_index.py
Normal file
42
tests/implicitron/test_data_json_index.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
from pytorch3d.renderer import PerspectiveCameras
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
# These tests are only run internally, where the data is available.
|
||||
internal = os.environ.get("FB_TEST", False)
|
||||
inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
|
||||
skip_tests = not internal or inside_re_worker
|
||||
|
||||
|
||||
@unittest.skipIf(skip_tests, "no data")
|
||||
class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
|
||||
def test_loaders(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
||||
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 1
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
|
||||
cameras = data_source.get_all_train_cameras()
|
||||
self.assertIsInstance(cameras, PerspectiveCameras)
|
||||
self.assertEqual(len(cameras), 81)
|
||||
|
||||
data_sets, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
|
||||
self.assertEqual(len(data_sets.train), 81)
|
||||
self.assertEqual(len(data_sets.val), 102)
|
||||
self.assertEqual(len(data_sets.test), 102)
|
@ -16,6 +16,7 @@ from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
|
||||
LlffDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
from pytorch3d.renderer import PerspectiveCameras
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
@ -123,3 +124,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
for i in data_loaders.test:
|
||||
self.assertEqual(i.frame_type, ["unseen"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
|
||||
cameras = data_source.get_all_train_cameras()
|
||||
self.assertIsInstance(cameras, PerspectiveCameras)
|
||||
self.assertEqual(len(cameras), 100)
|
||||
|
@ -24,11 +24,7 @@ from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
||||
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
|
||||
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from .common_resources import get_skateboard_data, provide_lpips_vgg
|
||||
else:
|
||||
from common_resources import get_skateboard_data, provide_lpips_vgg
|
||||
from .common_resources import get_skateboard_data, provide_lpips_vgg
|
||||
|
||||
|
||||
class TestEvaluation(unittest.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user