diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 972b44c0..380ffe4c 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -71,7 +71,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import ( ImplicitronDataset, ) from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate -from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel +from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel from pytorch3d.implicitron.tools import model_io, vis_utils from pytorch3d.implicitron.tools.config import ( enable_get_default_args, @@ -615,11 +615,11 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device): preds = model( **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION} ) - nvs_prediction = copy.deepcopy(preds["nvs_prediction"]) + implicitron_render = copy.deepcopy(preds["implicitron_render"]) per_batch_eval_results.append( evaluate.eval_batch( frame_data, - nvs_prediction, + implicitron_render, bg_color="black", lpips_model=lpips_model, source_cameras=all_source_cameras, diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py index 51be395f..a283ecef 100644 --- a/projects/implicitron_trainer/visualize_reconstruction.py +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -29,7 +29,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import ( ImplicitronDataset, ) from pytorch3d.implicitron.dataset.utils import is_train_frame -from pytorch3d.implicitron.models.base import EvaluationMode +from pytorch3d.implicitron.models.base_model import EvaluationMode from pytorch3d.implicitron.tools.configurable import get_default_args from pytorch3d.implicitron.tools.eval_video_trajectory import ( generate_eval_video_cameras, diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index 53cdb734..54464fcd 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -5,10 +5,9 @@ # LICENSE file in the root directory of this source tree. 
-import copy import dataclasses import os -from typing import cast, Optional +from typing import cast, Optional, Tuple import lpips import torch @@ -76,7 +75,7 @@ def main() -> None: def evaluate_dbir_for_category( category: str = "apple", - bg_color: float = 0.0, + bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0), task: str = "singlesequence", single_sequence_id: Optional[int] = None, num_workers: int = 16, @@ -141,8 +140,9 @@ def evaluate_dbir_for_category( raise ValueError("Image size should be set in the dataset") # init the simple DBIR model - model = ModelDBIR( - image_size=image_size, + model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden + render_image_width=image_size, + render_image_height=image_size, bg_color=bg_color, max_points=int(1e5), ) @@ -157,11 +157,10 @@ def evaluate_dbir_for_category( for frame_data in tqdm(test_dataloader): frame_data = dataclass_to_cuda_(frame_data) preds = model(**dataclasses.asdict(frame_data)) - nvs_prediction = copy.deepcopy(preds["nvs_prediction"]) per_batch_eval_results.append( eval_batch( frame_data, - nvs_prediction, + preds["implicitron_render"], bg_color=bg_color, lpips_model=lpips_model, source_cameras=all_source_cameras, diff --git a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py index df8fbaec..b863e073 100644 --- a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py +++ b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py @@ -9,12 +9,14 @@ import copy import warnings from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union import numpy as np import torch +import torch.nn.functional as F from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame +from pytorch3d.implicitron.models.base_model import ImplicitronRender from pytorch3d.implicitron.tools import vis_utils from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps from pytorch3d.implicitron.tools.image_utils import mask_background @@ -31,18 +33,6 @@ from visdom import Visdom EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9] -@dataclass -class NewViewSynthesisPrediction: - """ - Holds the tensors that describe a result of synthesizing new views. - """ - - depth_render: Optional[torch.Tensor] = None - image_render: Optional[torch.Tensor] = None - mask_render: Optional[torch.Tensor] = None - camera_distance: Optional[torch.Tensor] = None - - @dataclass class _Visualizer: image_render: torch.Tensor @@ -145,8 +135,8 @@ class _Visualizer: def eval_batch( frame_data: FrameData, - nvs_prediction: NewViewSynthesisPrediction, - bg_color: Union[torch.Tensor, str, float] = "black", + implicitron_render: ImplicitronRender, + bg_color: Union[torch.Tensor, Sequence, str, float] = "black", mask_thr: float = 0.5, lpips_model=None, visualize: bool = False, @@ -162,14 +152,14 @@ def eval_batch( is True), a new-view synthesis method (NVS) is tasked to generate new views of the scene from the viewpoint of the target views (for which frame_data.frame_type.endswith('known') is False). 
The resulting - synthesized new views, stored in `nvs_prediction`, are compared to the + synthesized new views, stored in `implicitron_render`, are compared to the target ground truth in `frame_data` in terms of geometry and appearance resulting in a dictionary of metrics returned by the `eval_batch` function. Args: frame_data: A FrameData object containing the input to the new view synthesis method. - nvs_prediction: The data describing the synthesized new views. + implicitron_render: The data describing the synthesized new views. bg_color: The background color of the generated new views and the ground truth. lpips_model: A pre-trained model for evaluating the LPIPS metric. @@ -184,26 +174,39 @@ def eval_batch( ValueError if frame_data does not have frame_type, camera, or image_rgb ValueError if the batch has a mix of training and test samples ValueError if the batch frames are not [unseen, known, known, ...] - ValueError if one of the required fields in nvs_prediction is missing + ValueError if one of the required fields in implicitron_render is missing """ - REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"] frame_type = frame_data.frame_type if frame_type is None: raise ValueError("Frame type has not been set.") # we check that all those fields are not None but Pyre can't infer that properly - # TODO: assign to local variables + # TODO: assign to local variables and simplify the code. if frame_data.image_rgb is None: raise ValueError("Image is not in the evaluation batch.") if frame_data.camera is None: raise ValueError("Camera is not in the evaluation batch.") - if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS): - raise ValueError("One of the required predicted fields is missing") + # eval all results in the resolution of the frame_data image + image_resol = tuple(frame_data.image_rgb.shape[2:]) + + # Post-process the render: + # 1) check implicitron_render for Nones, + # 2) obtain copies to make sure we don't edit the original data, + # 3) take only the 1st (target) image, + # 4) resize to match the ground-truth resolution + cloned_render: Dict[str, torch.Tensor] = {} + for k in ["mask_render", "image_render", "depth_render"]: + field = getattr(implicitron_render, k) + if field is None: + raise ValueError(f"A required predicted field {k} is missing") + + imode = "bilinear" if k == "image_render" else "nearest" + cloned_render[k] = ( + F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone() + ) - # obtain copies to make sure we dont edit the original data - nvs_prediction = copy.deepcopy(nvs_prediction) frame_data = copy.deepcopy(frame_data) # mask the ground truth depth in case frame_data contains the depth mask @@ -226,9 +229,6 @@ def eval_batch( + " a target view while the rest should be source views." ) # TODO: do we need to enforce this? - # take only the first (target image) - for k in REQUIRED_NVS_PREDICTION_FIELDS: - setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1]) for k in [ "depth_map", "image_rgb", @@ -242,10 +242,6 @@ def eval_batch( if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0: warnings.warn("Empty or missing depth map in evaluation!") - # eval all results in the resolution of the frame_data image - # pyre-fixme[16]: `Optional` has no attribute `shape`.
- image_resol = list(frame_data.image_rgb.shape[2:]) - # threshold the masks to make ground truth binary masks mask_fg, mask_crop = [ (getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop") @@ -258,29 +254,14 @@ def eval_batch( bg_color=bg_color, ) - # resize to the target resolution - for k in REQUIRED_NVS_PREDICTION_FIELDS: - imode = "bilinear" if k == "image_render" else "nearest" - val = getattr(nvs_prediction, k) - setattr( - nvs_prediction, - k, - # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got - # `List[typing.Any]`. - torch.nn.functional.interpolate(val, size=image_resol, mode=imode), - ) - # clamp predicted images - # pyre-fixme[16]: `Optional` has no attribute `clamp`. - image_render = nvs_prediction.image_render.clamp(0.0, 1.0) + image_render = cloned_render["image_render"].clamp(0.0, 1.0) if visualize: visualizer = _Visualizer( image_render=image_render, image_rgb_masked=image_rgb_masked, - # pyre-fixme[6]: Expected `Tensor` for 3rd param but got - # `Optional[torch.Tensor]`. - depth_render=nvs_prediction.depth_render, + depth_render=cloned_render["depth_render"], # pyre-fixme[6]: Expected `Tensor` for 4th param but got # `Optional[torch.Tensor]`. depth_map=frame_data.depth_map, @@ -292,9 +273,7 @@ def eval_batch( results: Dict[str, Any] = {} results["iou"] = iou( - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - nvs_prediction.mask_render, + cloned_render["mask_render"], mask_fg, mask=mask_crop, ) @@ -321,11 +300,7 @@ def eval_batch( if name_postfix == "_fg": # only record depth metrics for the foreground _, abs_ = eval_depth( - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - nvs_prediction.depth_render, - # pyre-fixme[6]: Expected `Tensor` for 2nd param but got - # `Optional[torch.Tensor]`. + cloned_render["depth_render"], frame_data.depth_map, get_best_scale=True, mask=loss_mask_now, @@ -343,7 +318,7 @@ def eval_batch( if lpips_model is not None: im1, im2 = [ 2.0 * im.clamp(0.0, 1.0) - 1.0 - for im in (image_rgb_masked, nvs_prediction.image_render) + for im in (image_rgb_masked, cloned_render["image_render"]) ] results["lpips"] = lpips_model.forward(im1, im2).item() diff --git a/pytorch3d/implicitron/models/base_model.py b/pytorch3d/implicitron/models/base_model.py new file mode 100644 index 00000000..1bcb577c --- /dev/null +++ b/pytorch3d/implicitron/models/base_model.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from pytorch3d.implicitron.tools.config import ReplaceableBase +from pytorch3d.renderer.cameras import CamerasBase + +from .renderer.base import EvaluationMode + + +@dataclass +class ImplicitronRender: + """ + Holds the tensors that describe a result of rendering. 
+ """ + + depth_render: Optional[torch.Tensor] = None + image_render: Optional[torch.Tensor] = None + mask_render: Optional[torch.Tensor] = None + camera_distance: Optional[torch.Tensor] = None + + def clone(self) -> "ImplicitronRender": + def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return t.detach().clone() if t is not None else None + + return ImplicitronRender( + depth_render=safe_clone(self.depth_render), + image_render=safe_clone(self.image_render), + mask_render=safe_clone(self.mask_render), + camera_distance=safe_clone(self.camera_distance), + ) + + +class ImplicitronModelBase(ReplaceableBase): + """Replaceable abstract base for all image generation / rendering models. + `forward()` method produces a render with a depth map. + """ + + def __init__(self) -> None: + super().__init__() + + def forward( + self, + *, # force keyword-only arguments + image_rgb: Optional[torch.Tensor], + camera: CamerasBase, + fg_probability: Optional[torch.Tensor], + mask_crop: Optional[torch.Tensor], + depth_map: Optional[torch.Tensor], + sequence_name: Optional[List[str]], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + **kwargs, + ) -> Dict[str, Any]: + """ + Args: + image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images; + the first `min(B, n_train_target_views)` images are considered targets and + are used to supervise the renders; the rest corresponding to the source + viewpoints from which features will be extracted. + camera: An instance of CamerasBase containing a batch of `B` cameras corresponding + to the viewpoints of target images, from which the rays will be sampled, + and source images, which will be used for intersecting with target rays. + fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of + foreground masks. + mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid + regions in the input images (i.e. regions that do not correspond + to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to + "mask_sample", rays will be sampled in the non zero regions. + depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps. + sequence_name: A list of `B` strings corresponding to the sequence names + from which images `image_rgb` were extracted. They are used to match + target frames with relevant source frames. + evaluation_mode: one of EvaluationMode.TRAINING or + EvaluationMode.EVALUATION which determines the settings used for + rendering. + + Returns: + preds: A dictionary containing all outputs of the forward pass. All models should + output an instance of `ImplicitronRender` in `preds["implicitron_render"]`. 
+ """ + raise NotImplementedError() diff --git a/pytorch3d/implicitron/models/base.py b/pytorch3d/implicitron/models/generic_model.py similarity index 99% rename from pytorch3d/implicitron/models/base.py rename to pytorch3d/implicitron/models/generic_model.py index 4b5d6392..1683d017 100644 --- a/pytorch3d/implicitron/models/base.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional, Tuple import torch import tqdm -from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( - NewViewSynthesisPrediction, -) from pytorch3d.implicitron.tools import image_utils, vis_utils -from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation +from pytorch3d.implicitron.tools.config import ( + registry, + run_auto_creation, +) from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples from pytorch3d.implicitron.tools.utils import cat_dataclass from pytorch3d.renderer import RayBundle, utils as rend_utils @@ -25,6 +25,7 @@ from pytorch3d.renderer.cameras import CamerasBase from visdom import Visdom from .autodecoder import Autodecoder +from .base_model import ImplicitronModelBase, ImplicitronRender from .implicit_function.base import ImplicitFunctionBase from .implicit_function.idr_feature_field import IdrFeatureField # noqa from .implicit_function.neural_radiance_field import ( # noqa @@ -56,8 +57,8 @@ STD_LOG_VARS = ["objective", "epoch", "sec/it"] logger = logging.getLogger(__name__) -# pyre-ignore: 13 -class GenericModel(Configurable, torch.nn.Module): +@registry.register +class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 """ GenericModel is a wrapper for the neural implicit rendering and reconstruction pipeline which consists @@ -452,7 +453,7 @@ class GenericModel(Configurable, torch.nn.Module): preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2) preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2) - preds["nvs_prediction"] = NewViewSynthesisPrediction( + preds["implicitron_render"] = ImplicitronRender( image_render=preds["images_render"], depth_render=preds["depths_render"], mask_render=preds["masks_render"], diff --git a/pytorch3d/implicitron/models/model_dbir.py b/pytorch3d/implicitron/models/model_dbir.py index 7f3031dc..780d5995 100644 --- a/pytorch3d/implicitron/models/model_dbir.py +++ b/pytorch3d/implicitron/models/model_dbir.py @@ -5,13 +5,11 @@ # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple import torch from pytorch3d.implicitron.dataset.utils import is_known_frame -from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( - NewViewSynthesisPrediction, -) +from pytorch3d.implicitron.tools.config import registry from pytorch3d.implicitron.tools.point_cloud_utils import ( get_rgbd_point_cloud, render_point_cloud_pytorch3d, @@ -19,41 +17,43 @@ from pytorch3d.implicitron.tools.point_cloud_utils import ( from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.structures import Pointclouds +from .base_model import ImplicitronModelBase, ImplicitronRender +from .renderer.base import EvaluationMode -class ModelDBIR(torch.nn.Module): + +@registry.register +class ModelDBIR(ImplicitronModelBase, torch.nn.Module): """ A simple depth-based image rendering model. + + Args: + render_image_width: The width of the rendered rectangular images. + render_image_height: The height of the rendered rectangular images. 
+ bg_color: The color of the background. + max_points: Maximum number of points in the point cloud + formed by unprojecting all source view depths. + If more points are present, they are randomly subsampled + to this number of points without replacement. """ - def __init__( - self, - image_size: int = 256, - bg_color: float = 0.0, - max_points: int = -1, - ): - """ - Initializes a simple DBIR model. - - Args: - image_size: The size of the rendered rectangular images. - bg_color: The color of the background. - max_points: Maximum number of points in the point cloud - formed by unprojecting all source view depths. - If more points are present, they are randomly subsampled - to #max_size points without replacement. - """ + render_image_width: int = 256 + render_image_height: int = 256 + bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0) + max_points: int = -1 + def __post_init__(self): super().__init__() - self.image_size = image_size - self.bg_color = bg_color - self.max_points = max_points def forward( self, + *, # force keyword-only arguments + image_rgb: Optional[torch.Tensor], camera: CamerasBase, - image_rgb: torch.Tensor, - depth_map: torch.Tensor, - fg_probability: torch.Tensor, + fg_probability: Optional[torch.Tensor], + mask_crop: Optional[torch.Tensor], + depth_map: Optional[torch.Tensor], + sequence_name: Optional[List[str]], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, frame_type: List[str], **kwargs, ) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass @@ -72,12 +72,21 @@ class ModelDBIR(torch.nn.Module): Returns: preds: A dict with the following fields: - nvs_prediction: The rendered colors, depth and mask + implicitron_render: The rendered colors, depth and mask of the target views. point_cloud: The point cloud of the scene. It's renders are - stored in `nvs_prediction`. + stored in `implicitron_render`. """ + if image_rgb is None: + raise ValueError("ModelDBIR needs image input") + + if fg_probability is None: + raise ValueError("ModelDBIR needs foreground mask input") + + if depth_map is None: + raise ValueError("ModelDBIR needs depth map input") + is_known = is_known_frame(frame_type) is_known_idx = torch.where(is_known)[0] @@ -108,7 +117,7 @@ class ModelDBIR(torch.nn.Module): _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d( camera[int(tgt_idx)], point_cloud, - render_size=(self.image_size, self.image_size), + render_size=(self.render_image_height, self.render_image_width), point_radius=1e-2, topk=10, bg_color=self.bg_color, @@ -121,7 +130,7 @@ class ModelDBIR(torch.nn.Module): image_render.append(_image_render) mask_render.append(_mask_render) - nvs_prediction = NewViewSynthesisPrediction( + implicitron_render = ImplicitronRender( **{ k: torch.cat(v, dim=0) for k, v in zip( @@ -132,7 +141,7 @@ class ModelDBIR(torch.nn.Module): ) preds = { - "nvs_prediction": nvs_prediction, + "implicitron_render": implicitron_render, "point_cloud": point_cloud, } diff --git a/pytorch3d/implicitron/tools/image_utils.py b/pytorch3d/implicitron/tools/image_utils.py index 33926f31..29c7e0a4 100644 --- a/pytorch3d/implicitron/tools/image_utils.py +++ b/pytorch3d/implicitron/tools/image_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
-from typing import Union +from typing import Sequence, Union import torch @@ -14,7 +14,7 @@ def mask_background( image_rgb: torch.Tensor, mask_fg: torch.Tensor, dim_color: int = 1, - bg_color: Union[torch.Tensor, str, float] = 0.0, + bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0, ) -> torch.Tensor: """ Mask the background input image tensor `image_rgb` with `bg_color`. @@ -26,9 +26,11 @@ def mask_background( # obtain the background color tensor if isinstance(bg_color, torch.Tensor): bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb) - elif isinstance(bg_color, float): + elif isinstance(bg_color, (float, tuple, list)): + if isinstance(bg_color, float): + bg_color = [bg_color] * 3 bg_color_t = torch.tensor( - [bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype + bg_color, device=image_rgb.device, dtype=image_rgb.dtype ).view(*tgt_view) elif isinstance(bg_color, str): if bg_color == "white": diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py index 6ab5ecb2..a4f94343 100644 --- a/tests/implicitron/test_config_use.py +++ b/tests/implicitron/test_config_use.py @@ -9,7 +9,7 @@ import unittest from omegaconf import OmegaConf from pytorch3d.implicitron.models.autodecoder import Autodecoder -from pytorch3d.implicitron.models.base import GenericModel +from pytorch3d.implicitron.models.generic_model import GenericModel from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( IdrFeatureField, ) diff --git a/tests/implicitron/test_evaluation.py b/tests/implicitron/test_evaluation.py index 34ef0ee4..95f4c9d3 100644 --- a/tests/implicitron/test_evaluation.py +++ b/tests/implicitron/test_evaluation.py @@ -6,7 +6,6 @@ import contextlib -import copy import dataclasses import itertools import math @@ -19,8 +18,13 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import ( FrameData, ImplicitronDataset, ) -from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch -from pytorch3d.implicitron.models.model_dbir import ModelDBIR +from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( + eval_batch, +) +from pytorch3d.implicitron.models.base_model import ImplicitronModelBase +from pytorch3d.implicitron.models.generic_model import GenericModel # noqa +from pytorch3d.implicitron.models.model_dbir import ModelDBIR # noqa +from pytorch3d.implicitron.tools.config import expand_args_fields, registry from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_ @@ -43,7 +47,7 @@ class TestEvaluation(unittest.TestCase): category = "skateboard" frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") - self.image_size = 256 + self.image_size = 64 self.dataset = ImplicitronDataset( frame_annotations_file=frame_file, sequence_annotations_file=sequence_file, @@ -53,11 +57,11 @@ class TestEvaluation(unittest.TestCase): box_crop=True, path_manager=path_manager, ) - self.bg_color = 0.0 + self.bg_color = (0.0, 0.0, 0.0) # init the lpips model for eval provide_lpips_vgg() - self.lpips_model = lpips.LPIPS(net="vgg") + self.lpips_model = lpips.LPIPS(net="vgg").cuda() def test_eval_depth(self): """ @@ -200,30 +204,17 @@ class TestEvaluation(unittest.TestCase): def _one_sequence_test( self, seq_dataset, - n_batches=2, - min_batch_size=5, - max_batch_size=10, + model, + batch_indices, + check_metrics=False, ): - # form 
a list of random batches - batch_indices = [] - for _ in range(n_batches): - batch_size = torch.randint( - low=min_batch_size, high=max_batch_size, size=(1,) - ) - batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size]) - loader = torch.utils.data.DataLoader( seq_dataset, - # batch_size=1, shuffle=False, batch_sampler=batch_indices, collate_fn=FrameData.collate, ) - model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color) - model.cuda() - self.lpips_model.cuda() - for frame_data in loader: self.assertIsNone(frame_data.frame_type) self.assertIsNotNone(frame_data.image_rgb) @@ -233,61 +224,101 @@ class TestEvaluation(unittest.TestCase): *(["train_known"] * (len(frame_data.image_rgb) - 1)), ] - # move frame_data to gpu frame_data = dataclass_to_cuda_(frame_data) preds = model(**dataclasses.asdict(frame_data)) - nvs_prediction = copy.deepcopy(preds["nvs_prediction"]) eval_result = eval_batch( frame_data, - nvs_prediction, + preds["implicitron_render"], bg_color=self.bg_color, lpips_model=self.lpips_model, ) - # Make a terribly bad NVS prediction and check that this is worse - # than the DBIR prediction. - nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"]) - nvs_prediction_bad.depth_render += ( - torch.randn_like(nvs_prediction.depth_render) * 100.0 - ) - nvs_prediction_bad.image_render += ( - torch.randn_like(nvs_prediction.image_render) * 100.0 - ) - nvs_prediction_bad.mask_render = ( - torch.randn_like(nvs_prediction.mask_render) > 0.0 - ).float() - eval_result_bad = eval_batch( - frame_data, - nvs_prediction_bad, - bg_color=self.bg_color, - lpips_model=self.lpips_model, - ) - - lower_better = { - "psnr": False, - "psnr_fg": False, - "depth_abs_fg": True, - "iou": False, - "rgb_l1": True, - "rgb_l1_fg": True, - } - - for metric in lower_better.keys(): - m_better = eval_result[metric] - m_worse = eval_result_bad[metric] - if m_better != m_better or m_worse != m_worse: - continue # metric is missing, i.e. NaN - _assert = ( - self.assertLessEqual - if lower_better[metric] - else self.assertGreaterEqual + if check_metrics: + self._check_metrics( + frame_data, preds["implicitron_render"], eval_result ) - _assert(m_better, m_worse) + + def _check_metrics(self, frame_data, implicitron_render, eval_result): + # Make a terribly bad NVS prediction and check that this is worse + # than the DBIR prediction. + implicitron_render_bad = implicitron_render.clone() + implicitron_render_bad.depth_render += ( + torch.randn_like(implicitron_render_bad.depth_render) * 100.0 + ) + implicitron_render_bad.image_render += ( + torch.randn_like(implicitron_render_bad.image_render) * 100.0 + ) + implicitron_render_bad.mask_render = ( + torch.randn_like(implicitron_render_bad.mask_render) > 0.0 + ).float() + eval_result_bad = eval_batch( + frame_data, + implicitron_render_bad, + bg_color=self.bg_color, + lpips_model=self.lpips_model, + ) + + lower_better = { + "psnr": False, + "psnr_fg": False, + "depth_abs_fg": True, + "iou": False, + "rgb_l1": True, + "rgb_l1_fg": True, + } + + for metric in lower_better: + m_better = eval_result[metric] + m_worse = eval_result_bad[metric] + if m_better != m_better or m_worse != m_worse: + continue # metric is missing, i.e. 
NaN + _assert = ( + self.assertLessEqual + if lower_better[metric] + else self.assertGreaterEqual + ) + _assert(m_better, m_worse) + + def _get_random_batch_indices( + self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10 + ): + batch_indices = [] + for _ in range(n_batches): + batch_size = torch.randint( + low=min_batch_size, high=max_batch_size, size=(1,) + ) + batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size]) + + return batch_indices def test_full_eval(self, n_sequences=5): """Test evaluation.""" + + # caching batch indices first to preserve RNG state + seq_datasets = {} + batch_indices = {} for seq in itertools.islice(self.dataset.sequence_names(), n_sequences): idx = list(self.dataset.sequence_indices_in_order(seq)) seq_dataset = torch.utils.data.Subset(self.dataset, idx) - self._one_sequence_test(seq_dataset) + seq_datasets[seq] = seq_dataset + batch_indices[seq] = self._get_random_batch_indices(seq_dataset) + + for model_class_type in ["ModelDBIR", "GenericModel"]: + ModelClass = registry.get(ImplicitronModelBase, model_class_type) + expand_args_fields(ModelClass) + model = ModelClass( + render_image_width=self.image_size, + render_image_height=self.image_size, + bg_color=self.bg_color, + ) + model.eval() + model.cuda() + + for seq in itertools.islice(self.dataset.sequence_names(), n_sequences): + self._one_sequence_test( + seq_datasets[seq], + model, + batch_indices[seq], + check_metrics=(model_class_type == "ModelDBIR"), + ) diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py index e1909248..10ec48f7 100644 --- a/tests/implicitron/test_forward_pass.py +++ b/tests/implicitron/test_forward_pass.py @@ -7,7 +7,7 @@ import unittest import torch -from pytorch3d.implicitron.models.base import GenericModel +from pytorch3d.implicitron.models.generic_model import GenericModel from pytorch3d.implicitron.models.renderer.base import EvaluationMode from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
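Usage sketch (illustrative, not part of the diff): after this change, downstream code consumes preds["implicitron_render"] (an ImplicitronRender) instead of preds["nvs_prediction"], and models are built through the ImplicitronModelBase registry. The snippet below mirrors the calling pattern exercised in the updated tests/implicitron/test_evaluation.py; the `frame_data` batch, the image size of 64, and the omitted LPIPS model are placeholder assumptions, while the registry lookup, expand_args_fields call, constructor arguments, and eval_batch signature are the ones introduced above.

import dataclasses

from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.model_dbir import ModelDBIR  # noqa: F401 (importing registers the model)
from pytorch3d.implicitron.tools.config import expand_args_fields, registry

# Resolve the concrete model through the ImplicitronModelBase registry instead of
# calling its constructor directly.
ModelClass = registry.get(ImplicitronModelBase, "ModelDBIR")
expand_args_fields(ModelClass)
model = ModelClass(
    render_image_width=64,      # replaces the old single `image_size` argument
    render_image_height=64,
    bg_color=(0.0, 0.0, 0.0),   # now an RGB tuple instead of a scalar float
)
model.eval()

# `frame_data` is assumed to be a collated FrameData batch on the right device,
# prepared as in tests/implicitron/test_evaluation.py (placeholder, not defined here).
preds = model(**dataclasses.asdict(frame_data))

# The render is returned under the new key and no longer needs a manual deepcopy:
# eval_batch clones and resizes the ImplicitronRender fields internally.
metrics = eval_batch(
    frame_data,
    preds["implicitron_render"],
    bg_color=(0.0, 0.0, 0.0),
    lpips_model=None,  # pass an lpips.LPIPS(net="vgg") instance to also report LPIPS
)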