# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
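
"""
Evaluates new-view-synthesis metrics of a simple depth-based image rendering
(DBIR) model on the CO3D dataset; see `main` below for details.
"""
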
import copy
import dataclasses
import os
from typing import cast, Optional

import lpips
import torch
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
    FrameData,
    ImplicitronDataset,
    ImplicitronDatasetBase,
)
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
    aggregate_nvs_results,
    eval_batch,
    pretty_print_nvs_metrics,
    summarize_nvs_eval_results,
)
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
from tqdm import tqdm
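
# The script reads the dataset location from the CO3D_DATASET_ROOT environment
# variable; a typical invocation (path and script name illustrative) would be:
#   CO3D_DATASET_ROOT=/path/to/co3d python eval_demo.py
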
def main() -> None:
    """
    Evaluates new view synthesis metrics of a simple depth-based image rendering
    (DBIR) model for multisequence/singlesequence tasks for several categories.

    The evaluation is conducted on the same data as in [1] and, hence, the results
    are directly comparable to the numbers reported in [1].

    References:
        [1] J. Reizenstein, R. Shapovalov, P. Henzler, L. Sbordone,
            P. Labatut, D. Novotny:
            Common Objects in 3D: Large-Scale Learning
            and Evaluation of Real-life 3D Category Reconstruction
    """

    task_results = {}
    for task in ("singlesequence", "multisequence"):
        task_results[task] = []
        for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
            for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
                category_result = evaluate_dbir_for_category(
                    category, task=task, single_sequence_id=single_sequence_id
                )
                print("")
                print(
                    f"Results for task={task}; category={category};"
                    + (
                        f" sequence={single_sequence_id}:"
                        if single_sequence_id is not None
                        else ":"
                    )
                )
                pretty_print_nvs_metrics(category_result)
                print("")

                task_results[task].append(category_result)
            _print_aggregate_results(task, task_results)

    for task in task_results:
        _print_aggregate_results(task, task_results)

def evaluate_dbir_for_category(
    category: str = "apple",
    bg_color: float = 0.0,
    task: str = "singlesequence",
    single_sequence_id: Optional[int] = None,
    num_workers: int = 16,
    path_manager: Optional[PathManager] = None,
):
    """
    Evaluates new view synthesis metrics of a simple depth-based image rendering
    (DBIR) model for a given task, category, and sequence (in case
    task=='singlesequence').

    Args:
        category: Object category.
        bg_color: Background color of the renders.
        task: Evaluation task. Either "singlesequence" or "multisequence".
        single_sequence_id: The ID of the evaluation sequence for the
            singlesequence task.
        num_workers: The number of workers for the employed dataloaders.
        path_manager: (optional) Used for interpreting paths.

    Returns:
        category_result: A dictionary of quantitative metrics.
    """

    single_sequence_id = single_sequence_id if single_sequence_id is not None else -1

    torch.manual_seed(42)

    if task not in ["multisequence", "singlesequence"]:
        raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
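
    # Build the CO3D dataset splits and their dataloaders for the chosen task.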
    datasets = dataset_zoo(
        category=category,
        dataset_root=os.environ["CO3D_DATASET_ROOT"],
        assert_single_seq=task == "singlesequence",
        dataset_name=f"co3d_{task}",
        test_on_train=False,
        load_point_clouds=True,
        test_restrict_sequence_id=single_sequence_id,
        path_manager=path_manager,
    )

    dataloaders = dataloader_zoo(
        datasets,
        dataset_name=f"co3d_{task}",
    )

    test_dataset = datasets["test"]
    test_dataloader = dataloaders["test"]

    if task == "singlesequence":
        # all_source_cameras are needed for evaluation of the
        # target camera difficulty
        # pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
        sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
        all_source_cameras = _get_all_source_cameras(
            test_dataset, sequence_name, num_workers=num_workers
        )
    else:
        all_source_cameras = None

    image_size = cast(ImplicitronDataset, test_dataset).image_width

    if image_size is None:
        raise ValueError("Image size should be set in the dataset")

    # init the simple DBIR model
    model = ModelDBIR(
        image_size=image_size,
        bg_color=bg_color,
        max_points=int(1e5),
    )
    model.cuda()

    # init the lpips model for eval
    lpips_model = lpips.LPIPS(net="vgg")
    lpips_model = lpips_model.cuda()
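
    # Render each test batch with the DBIR model and score the novel-view
    # prediction against the ground truth with eval_batch.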
    per_batch_eval_results = []
    print("Evaluating DBIR model ...")
    for frame_data in tqdm(test_dataloader):
        frame_data = dataclass_to_cuda_(frame_data)
        preds = model(**dataclasses.asdict(frame_data))
        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
        per_batch_eval_results.append(
            eval_batch(
                frame_data,
                nvs_prediction,
                bg_color=bg_color,
                lpips_model=lpips_model,
                source_cameras=all_source_cameras,
            )
        )

    category_result_flat, category_result = summarize_nvs_eval_results(
        per_batch_eval_results, task
    )

    return category_result["results"]

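# A minimal single-category usage sketch (assumes a CUDA device and the
# corresponding CO3D subset downloaded under CO3D_DATASET_ROOT):
#
#   result = evaluate_dbir_for_category(
#       "apple", task="singlesequence", single_sequence_id=0
#   )
#   pretty_print_nvs_metrics(result)
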
def _print_aggregate_results(task, task_results) -> None:
    """
    Prints the aggregate metrics for a given task.
    """
    aggregate_task_result = aggregate_nvs_results(task_results[task])
    print("")
    print(f"Aggregate results for task={task}:")
    pretty_print_nvs_metrics(aggregate_task_result)
    print("")

def _get_all_source_cameras(
    dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
):
    """
    Loads all training cameras of a given sequence.

    The set of all seen cameras is needed for evaluating the viewpoint difficulty
    for the singlescene evaluation.

    Args:
        dataset: CO3D dataset object.
        sequence_name: The name of the sequence.
        num_workers: The number of workers for the utilized dataloader.
    """

    # load all source cameras of the sequence
    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
    dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
    # batch_size equals the sequence length, so the loader yields exactly one
    # batch holding every frame of the sequence
    (all_frame_data,) = torch.utils.data.DataLoader(
        dataset_for_loader,
        shuffle=False,
        batch_size=len(dataset_for_loader),
        num_workers=num_workers,
        collate_fn=FrameData.collate,
    )
    is_known = is_known_frame(all_frame_data.frame_type)
    source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
    return source_cameras

if __name__ == "__main__":
    main()