Extracted ImplicitronModelBase and unified API for GenericModel and ModelDBIR

Summary:
To avoid a model_zoo, we need to make GenericModel pluggable.
I also align the creation APIs of GenericModel and ModelDBIR for convenience.
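For example (a sketch based on the updated tests in this diff; the resolution and background color are arbitrary), any registered ImplicitronModelBase implementation can now be looked up by name and constructed with the same arguments:

from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import GenericModel  # noqa: registers the model
from pytorch3d.implicitron.models.model_dbir import ModelDBIR  # noqa: registers the model
from pytorch3d.implicitron.tools.config import expand_args_fields, registry

# Look up an implementation by name; "GenericModel" works the same way.
ModelClass = registry.get(ImplicitronModelBase, "ModelDBIR")
expand_args_fields(ModelClass)
model = ModelClass(
    render_image_width=256,
    render_image_height=256,
    bg_color=(0.0, 0.0, 0.0),
)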

Reviewed By: bottler, davnov134

Differential Revision: D35933093

fbshipit-source-id: 8228926528eb41a795fbfbe32304b8019197e2b1
Roman Shapovalov 2022-05-09 15:23:07 -07:00 committed by Facebook GitHub Bot
parent 5c59841863
commit a6dada399d
11 changed files with 282 additions and 178 deletions

View File

@ -71,7 +71,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
ImplicitronDataset,
)
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
enable_get_default_args,
@ -615,11 +615,11 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device):
preds = model(
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
)
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
implicitron_render = copy.deepcopy(preds["implicitron_render"])
per_batch_eval_results.append(
evaluate.eval_batch(
frame_data,
nvs_prediction,
implicitron_render,
bg_color="black",
lpips_model=lpips_model,
source_cameras=all_source_cameras,

View File

@ -29,7 +29,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
ImplicitronDataset,
)
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.models.base_model import EvaluationMode
from pytorch3d.implicitron.tools.configurable import get_default_args
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,

View File

@ -5,10 +5,9 @@
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import os
from typing import cast, Optional
from typing import cast, Optional, Tuple
import lpips
import torch
@ -76,7 +75,7 @@ def main() -> None:
def evaluate_dbir_for_category(
category: str = "apple",
bg_color: float = 0.0,
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
task: str = "singlesequence",
single_sequence_id: Optional[int] = None,
num_workers: int = 16,
@ -141,8 +140,9 @@ def evaluate_dbir_for_category(
raise ValueError("Image size should be set in the dataset")
# init the simple DBIR model
model = ModelDBIR(
image_size=image_size,
model = ModelDBIR( # pyre-ignore[28]: ctor implicitly overridden
render_image_width=image_size,
render_image_height=image_size,
bg_color=bg_color,
max_points=int(1e5),
)
@ -157,11 +157,10 @@ def evaluate_dbir_for_category(
for frame_data in tqdm(test_dataloader):
frame_data = dataclass_to_cuda_(frame_data)
preds = model(**dataclasses.asdict(frame_data))
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
per_batch_eval_results.append(
eval_batch(
frame_data,
nvs_prediction,
preds["implicitron_render"],
bg_color=bg_color,
lpips_model=lpips_model,
source_cameras=all_source_cameras,

View File

@ -9,12 +9,14 @@ import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender
from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
from pytorch3d.implicitron.tools.image_utils import mask_background
@ -31,18 +33,6 @@ from visdom import Visdom
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
@dataclass
class NewViewSynthesisPrediction:
"""
Holds the tensors that describe a result of synthesizing new views.
"""
depth_render: Optional[torch.Tensor] = None
image_render: Optional[torch.Tensor] = None
mask_render: Optional[torch.Tensor] = None
camera_distance: Optional[torch.Tensor] = None
@dataclass
class _Visualizer:
image_render: torch.Tensor
@ -145,8 +135,8 @@ class _Visualizer:
def eval_batch(
frame_data: FrameData,
nvs_prediction: NewViewSynthesisPrediction,
bg_color: Union[torch.Tensor, str, float] = "black",
implicitron_render: ImplicitronRender,
bg_color: Union[torch.Tensor, Sequence, str, float] = "black",
mask_thr: float = 0.5,
lpips_model=None,
visualize: bool = False,
@ -162,14 +152,14 @@ def eval_batch(
is True), a new-view synthesis method (NVS) is tasked to generate new views
of the scene from the viewpoint of the target views (for which
frame_data.frame_type.endswith('known') is False). The resulting
synthesized new views, stored in `nvs_prediction`, are compared to the
synthesized new views, stored in `implicitron_render`, are compared to the
target ground truth in `frame_data` in terms of geometry and appearance
resulting in a dictionary of metrics returned by the `eval_batch` function.
Args:
frame_data: A FrameData object containing the input to the new view
synthesis method.
nvs_prediction: The data describing the synthesized new views.
implicitron_render: The data describing the synthesized new views.
bg_color: The background color of the generated new views and the
ground truth.
lpips_model: A pre-trained model for evaluating the LPIPS metric.
@ -184,26 +174,39 @@ def eval_batch(
ValueError if frame_data does not have frame_type, camera, or image_rgb
ValueError if the batch has a mix of training and test samples
ValueError if the batch frames are not [unseen, known, known, ...]
ValueError if one of the required fields in nvs_prediction is missing
ValueError if one of the required fields in implicitron_render is missing
"""
REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
frame_type = frame_data.frame_type
if frame_type is None:
raise ValueError("Frame type has not been set.")
# we check that all those fields are not None but Pyre can't infer that properly
# TODO: assign to local variables
# TODO: assign to local variables and simplify the code.
if frame_data.image_rgb is None:
raise ValueError("Image is not in the evaluation batch.")
if frame_data.camera is None:
raise ValueError("Camera is not in the evaluation batch.")
if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
raise ValueError("One of the required predicted fields is missing")
# eval all results in the resolution of the frame_data image
image_resol = tuple(frame_data.image_rgb.shape[2:])
# Post-process the render:
# 1) check implicitron_render for Nones,
# 2) obtain copies to make sure we don't edit the original data,
# 3) take only the 1st (target) image
# 4) resize to match ground-truth resolution
cloned_render: Dict[str, torch.Tensor] = {}
for k in ["mask_render", "image_render", "depth_render"]:
field = getattr(implicitron_render, k)
if field is None:
raise ValueError(f"A required predicted field {k} is missing")
imode = "bilinear" if k == "image_render" else "nearest"
cloned_render[k] = (
F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone()
)
# obtain copies to make sure we don't edit the original data
nvs_prediction = copy.deepcopy(nvs_prediction)
frame_data = copy.deepcopy(frame_data)
# mask the ground truth depth in case frame_data contains the depth mask
@ -226,9 +229,6 @@ def eval_batch(
+ " a target view while the rest should be source views."
) # TODO: do we need to enforce this?
# take only the first (target image)
for k in REQUIRED_NVS_PREDICTION_FIELDS:
setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
for k in [
"depth_map",
"image_rgb",
@ -242,10 +242,6 @@ def eval_batch(
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
warnings.warn("Empty or missing depth map in evaluation!")
# eval all results in the resolution of the frame_data image
# pyre-fixme[16]: `Optional` has no attribute `shape`.
image_resol = list(frame_data.image_rgb.shape[2:])
# threshold the masks to make ground truth binary masks
mask_fg, mask_crop = [
(getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
@ -258,29 +254,14 @@ def eval_batch(
bg_color=bg_color,
)
# resize to the target resolution
for k in REQUIRED_NVS_PREDICTION_FIELDS:
imode = "bilinear" if k == "image_render" else "nearest"
val = getattr(nvs_prediction, k)
setattr(
nvs_prediction,
k,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
# `List[typing.Any]`.
torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
)
# clamp predicted images
# pyre-fixme[16]: `Optional` has no attribute `clamp`.
image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
image_render = cloned_render["image_render"].clamp(0.0, 1.0)
if visualize:
visualizer = _Visualizer(
image_render=image_render,
image_rgb_masked=image_rgb_masked,
# pyre-fixme[6]: Expected `Tensor` for 3rd param but got
# `Optional[torch.Tensor]`.
depth_render=nvs_prediction.depth_render,
depth_render=cloned_render["depth_render"],
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
# `Optional[torch.Tensor]`.
depth_map=frame_data.depth_map,
@ -292,9 +273,7 @@ def eval_batch(
results: Dict[str, Any] = {}
results["iou"] = iou(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
nvs_prediction.mask_render,
cloned_render["mask_render"],
mask_fg,
mask=mask_crop,
)
@ -321,11 +300,7 @@ def eval_batch(
if name_postfix == "_fg":
# only record depth metrics for the foreground
_, abs_ = eval_depth(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
nvs_prediction.depth_render,
# pyre-fixme[6]: Expected `Tensor` for 2nd param but got
# `Optional[torch.Tensor]`.
cloned_render["depth_render"],
frame_data.depth_map,
get_best_scale=True,
mask=loss_mask_now,
@ -343,7 +318,7 @@ def eval_batch(
if lpips_model is not None:
im1, im2 = [
2.0 * im.clamp(0.0, 1.0) - 1.0
for im in (image_rgb_masked, nvs_prediction.image_render)
for im in (image_rgb_masked, cloned_render["image_render"])
]
results["lpips"] = lpips_model.forward(im1, im2).item()

View File

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase
from .renderer.base import EvaluationMode
@dataclass
class ImplicitronRender:
"""
Holds the tensors that describe a result of rendering.
"""
depth_render: Optional[torch.Tensor] = None
image_render: Optional[torch.Tensor] = None
mask_render: Optional[torch.Tensor] = None
camera_distance: Optional[torch.Tensor] = None
def clone(self) -> "ImplicitronRender":
def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
return t.detach().clone() if t is not None else None
return ImplicitronRender(
depth_render=safe_clone(self.depth_render),
image_render=safe_clone(self.image_render),
mask_render=safe_clone(self.mask_render),
camera_distance=safe_clone(self.camera_distance),
)
class ImplicitronModelBase(ReplaceableBase):
"""Replaceable abstract base for all image generation / rendering models.
The `forward()` method produces a render with a depth map.
"""
def __init__(self) -> None:
super().__init__()
def forward(
self,
*, # force keyword-only arguments
image_rgb: Optional[torch.Tensor],
camera: CamerasBase,
fg_probability: Optional[torch.Tensor],
mask_crop: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
sequence_name: Optional[List[str]],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
) -> Dict[str, Any]:
"""
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
the first `min(B, n_train_target_views)` images are considered targets and
are used to supervise the renders; the rest correspond to the source
viewpoints from which features will be extracted.
camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
to the viewpoints of target images, from which the rays will be sampled,
and source images, which will be used for intersecting with target rays.
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
foreground masks.
mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
regions in the input images (i.e. regions that do not correspond
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
"mask_sample", rays will be sampled in the non-zero regions.
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
sequence_name: A list of `B` strings corresponding to the sequence names
from which images `image_rgb` were extracted. They are used to match
target frames with relevant source frames.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
Returns:
preds: A dictionary containing all outputs of the forward pass. All models should
output an instance of `ImplicitronRender` in `preds["implicitron_render"]`.
"""
raise NotImplementedError()
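As a minimal sketch (not part of this diff; the class and its constant output are hypothetical), a new model plugs in by registering an ImplicitronModelBase subclass and returning an ImplicitronRender under preds["implicitron_render"], mirroring the ModelDBIR pattern below:

from typing import Any, Dict

import torch
from pytorch3d.implicitron.models.base_model import (
    ImplicitronModelBase,
    ImplicitronRender,
)
from pytorch3d.implicitron.tools.config import registry


@registry.register
class ConstantGrayModel(ImplicitronModelBase, torch.nn.Module):  # hypothetical example
    render_image_width: int = 256
    render_image_height: int = 256

    def __post_init__(self):
        super().__init__()

    def forward(self, *, image_rgb, camera, **kwargs) -> Dict[str, Any]:
        # Render a constant mid-gray image at the configured resolution.
        shape = (1, 1, self.render_image_height, self.render_image_width)
        render = ImplicitronRender(
            image_render=image_rgb.new_full((1, 3, *shape[2:]), 0.5),
            depth_render=image_rgb.new_ones(shape),
            mask_render=image_rgb.new_ones(shape),
        )
        return {"implicitron_render": render}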

View File

@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
import tqdm
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.implicitron.tools.config import (
registry,
run_auto_creation,
)
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils
@ -25,6 +25,7 @@ from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom
from .autodecoder import Autodecoder
from .base_model import ImplicitronModelBase, ImplicitronRender
from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
from .implicit_function.neural_radiance_field import ( # noqa
@ -56,8 +57,8 @@ STD_LOG_VARS = ["objective", "epoch", "sec/it"]
logger = logging.getLogger(__name__)
# pyre-ignore: 13
class GenericModel(Configurable, torch.nn.Module):
@registry.register
class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
"""
GenericModel is a wrapper for the neural implicit
rendering and reconstruction pipeline which consists
@ -452,7 +453,7 @@ class GenericModel(Configurable, torch.nn.Module):
preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
preds["nvs_prediction"] = NewViewSynthesisPrediction(
preds["implicitron_render"] = ImplicitronRender(
image_render=preds["images_render"],
depth_render=preds["depths_render"],
mask_render=preds["masks_render"],

View File

@ -5,13 +5,11 @@
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Tuple
import torch
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.tools.point_cloud_utils import (
get_rgbd_point_cloud,
render_point_cloud_pytorch3d,
@ -19,41 +17,43 @@ from pytorch3d.implicitron.tools.point_cloud_utils import (
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from .base_model import ImplicitronModelBase, ImplicitronRender
from .renderer.base import EvaluationMode
class ModelDBIR(torch.nn.Module):
@registry.register
class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
"""
A simple depth-based image rendering model.
Args:
render_image_width: The width of the rendered rectangular images.
render_image_height: The height of the rendered rectangular images.
bg_color: The color of the background.
max_points: Maximum number of points in the point cloud
formed by unprojecting all source view depths.
If more points are present, they are randomly subsampled
to this number of points without replacement.
"""
def __init__(
self,
image_size: int = 256,
bg_color: float = 0.0,
max_points: int = -1,
):
"""
Initializes a simple DBIR model.
Args:
image_size: The size of the rendered rectangular images.
bg_color: The color of the background.
max_points: Maximum number of points in the point cloud
formed by unprojecting all source view depths.
If more points are present, they are randomly subsampled
to #max_size points without replacement.
"""
render_image_width: int = 256
render_image_height: int = 256
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
max_points: int = -1
def __post_init__(self):
super().__init__()
self.image_size = image_size
self.bg_color = bg_color
self.max_points = max_points
def forward(
self,
*, # force keyword-only arguments
image_rgb: Optional[torch.Tensor],
camera: CamerasBase,
image_rgb: torch.Tensor,
depth_map: torch.Tensor,
fg_probability: torch.Tensor,
fg_probability: Optional[torch.Tensor],
mask_crop: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
sequence_name: Optional[List[str]],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
frame_type: List[str],
**kwargs,
) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass
@ -72,12 +72,21 @@ class ModelDBIR(torch.nn.Module):
Returns:
preds: A dict with the following fields:
nvs_prediction: The rendered colors, depth and mask
implicitron_render: The rendered colors, depth and mask
of the target views.
point_cloud: The point cloud of the scene. Its renders are
stored in `nvs_prediction`.
stored in `implicitron_render`.
"""
if image_rgb is None:
raise ValueError("ModelDBIR needs image input")
if fg_probability is None:
raise ValueError("ModelDBIR needs foreground mask input")
if depth_map is None:
raise ValueError("ModelDBIR needs depth map input")
is_known = is_known_frame(frame_type)
is_known_idx = torch.where(is_known)[0]
@ -108,7 +117,7 @@ class ModelDBIR(torch.nn.Module):
_image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
camera[int(tgt_idx)],
point_cloud,
render_size=(self.image_size, self.image_size),
render_size=(self.render_image_height, self.render_image_width),
point_radius=1e-2,
topk=10,
bg_color=self.bg_color,
@ -121,7 +130,7 @@ class ModelDBIR(torch.nn.Module):
image_render.append(_image_render)
mask_render.append(_mask_render)
nvs_prediction = NewViewSynthesisPrediction(
implicitron_render = ImplicitronRender(
**{
k: torch.cat(v, dim=0)
for k, v in zip(
@ -132,7 +141,7 @@ class ModelDBIR(torch.nn.Module):
)
preds = {
"nvs_prediction": nvs_prediction,
"implicitron_render": implicitron_render,
"point_cloud": point_cloud,
}
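For existing ModelDBIR users, the constructor change looks like this (a sketch; `frame_data` stands for a collated FrameData batch):

# Before: model = ModelDBIR(image_size=256, bg_color=0.0, max_points=int(1e5))
# After, aligned with GenericModel:
model = ModelDBIR(
    render_image_width=256,
    render_image_height=256,
    bg_color=(0.0, 0.0, 0.0),
    max_points=int(1e5),
)
preds = model(**dataclasses.asdict(frame_data))
render = preds["implicitron_render"]  # an ImplicitronRender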

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Union
from typing import Sequence, Union
import torch
@ -14,7 +14,7 @@ def mask_background(
image_rgb: torch.Tensor,
mask_fg: torch.Tensor,
dim_color: int = 1,
bg_color: Union[torch.Tensor, str, float] = 0.0,
bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0,
) -> torch.Tensor:
"""
Mask the background input image tensor `image_rgb` with `bg_color`.
@ -26,9 +26,11 @@ def mask_background(
# obtain the background color tensor
if isinstance(bg_color, torch.Tensor):
bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
elif isinstance(bg_color, float):
elif isinstance(bg_color, (float, tuple, list)):
if isinstance(bg_color, float):
bg_color = [bg_color] * 3
bg_color_t = torch.tensor(
[bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
bg_color, device=image_rgb.device, dtype=image_rgb.dtype
).view(*tgt_view)
elif isinstance(bg_color, str):
if bg_color == "white":
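As a usage sketch of the widened signature, mask_background now accepts a per-channel sequence in addition to a scalar, a tensor, or a color name:

masked = mask_background(image_rgb, mask_fg, bg_color=(0.0, 0.0, 0.0))
masked = mask_background(image_rgb, mask_fg, bg_color=0.0)       # scalar broadcast to all channels
masked = mask_background(image_rgb, mask_fg, bg_color="white")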

View File

@ -9,7 +9,7 @@ import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)

View File

@ -6,7 +6,6 @@
import contextlib
import copy
import dataclasses
import itertools
import math
@ -19,8 +18,13 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
eval_batch,
)
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import GenericModel # noqa
from pytorch3d.implicitron.models.model_dbir import ModelDBIR # noqa
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
@ -43,7 +47,7 @@ class TestEvaluation(unittest.TestCase):
category = "skateboard"
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
self.image_size = 256
self.image_size = 64
self.dataset = ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
@ -53,11 +57,11 @@ class TestEvaluation(unittest.TestCase):
box_crop=True,
path_manager=path_manager,
)
self.bg_color = 0.0
self.bg_color = (0.0, 0.0, 0.0)
# init the lpips model for eval
provide_lpips_vgg()
self.lpips_model = lpips.LPIPS(net="vgg")
self.lpips_model = lpips.LPIPS(net="vgg").cuda()
def test_eval_depth(self):
"""
@ -200,30 +204,17 @@ class TestEvaluation(unittest.TestCase):
def _one_sequence_test(
self,
seq_dataset,
n_batches=2,
min_batch_size=5,
max_batch_size=10,
model,
batch_indices,
check_metrics=False,
):
# form a list of random batches
batch_indices = []
for _ in range(n_batches):
batch_size = torch.randint(
low=min_batch_size, high=max_batch_size, size=(1,)
)
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
loader = torch.utils.data.DataLoader(
seq_dataset,
# batch_size=1,
shuffle=False,
batch_sampler=batch_indices,
collate_fn=FrameData.collate,
)
model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
model.cuda()
self.lpips_model.cuda()
for frame_data in loader:
self.assertIsNone(frame_data.frame_type)
self.assertIsNotNone(frame_data.image_rgb)
@ -233,61 +224,101 @@ class TestEvaluation(unittest.TestCase):
*(["train_known"] * (len(frame_data.image_rgb) - 1)),
]
# move frame_data to gpu
frame_data = dataclass_to_cuda_(frame_data)
preds = model(**dataclasses.asdict(frame_data))
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
eval_result = eval_batch(
frame_data,
nvs_prediction,
preds["implicitron_render"],
bg_color=self.bg_color,
lpips_model=self.lpips_model,
)
# Make a terribly bad NVS prediction and check that this is worse
# than the DBIR prediction.
nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
nvs_prediction_bad.depth_render += (
torch.randn_like(nvs_prediction.depth_render) * 100.0
)
nvs_prediction_bad.image_render += (
torch.randn_like(nvs_prediction.image_render) * 100.0
)
nvs_prediction_bad.mask_render = (
torch.randn_like(nvs_prediction.mask_render) > 0.0
).float()
eval_result_bad = eval_batch(
frame_data,
nvs_prediction_bad,
bg_color=self.bg_color,
lpips_model=self.lpips_model,
)
lower_better = {
"psnr": False,
"psnr_fg": False,
"depth_abs_fg": True,
"iou": False,
"rgb_l1": True,
"rgb_l1_fg": True,
}
for metric in lower_better.keys():
m_better = eval_result[metric]
m_worse = eval_result_bad[metric]
if m_better != m_better or m_worse != m_worse:
continue # metric is missing, i.e. NaN
_assert = (
self.assertLessEqual
if lower_better[metric]
else self.assertGreaterEqual
if check_metrics:
self._check_metrics(
frame_data, preds["implicitron_render"], eval_result
)
_assert(m_better, m_worse)
def _check_metrics(self, frame_data, implicitron_render, eval_result):
# Make a terribly bad NVS prediction and check that this is worse
# than the DBIR prediction.
implicitron_render_bad = implicitron_render.clone()
implicitron_render_bad.depth_render += (
torch.randn_like(implicitron_render_bad.depth_render) * 100.0
)
implicitron_render_bad.image_render += (
torch.randn_like(implicitron_render_bad.image_render) * 100.0
)
implicitron_render_bad.mask_render = (
torch.randn_like(implicitron_render_bad.mask_render) > 0.0
).float()
eval_result_bad = eval_batch(
frame_data,
implicitron_render_bad,
bg_color=self.bg_color,
lpips_model=self.lpips_model,
)
lower_better = {
"psnr": False,
"psnr_fg": False,
"depth_abs_fg": True,
"iou": False,
"rgb_l1": True,
"rgb_l1_fg": True,
}
for metric in lower_better:
m_better = eval_result[metric]
m_worse = eval_result_bad[metric]
if m_better != m_better or m_worse != m_worse:
continue # metric is missing, i.e. NaN
_assert = (
self.assertLessEqual
if lower_better[metric]
else self.assertGreaterEqual
)
_assert(m_better, m_worse)
def _get_random_batch_indices(
self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10
):
batch_indices = []
for _ in range(n_batches):
batch_size = torch.randint(
low=min_batch_size, high=max_batch_size, size=(1,)
)
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
return batch_indices
def test_full_eval(self, n_sequences=5):
"""Test evaluation."""
# caching batch indices first to preserve RNG state
seq_datasets = {}
batch_indices = {}
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
idx = list(self.dataset.sequence_indices_in_order(seq))
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
self._one_sequence_test(seq_dataset)
seq_datasets[seq] = seq_dataset
batch_indices[seq] = self._get_random_batch_indices(seq_dataset)
for model_class_type in ["ModelDBIR", "GenericModel"]:
ModelClass = registry.get(ImplicitronModelBase, model_class_type)
expand_args_fields(ModelClass)
model = ModelClass(
render_image_width=self.image_size,
render_image_height=self.image_size,
bg_color=self.bg_color,
)
model.eval()
model.cuda()
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
self._one_sequence_test(
seq_datasets[seq],
model,
batch_indices[seq],
check_metrics=(model_class_type == "ModelDBIR"),
)

View File

@ -7,7 +7,7 @@
import unittest
import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras