Extracted ImplicitronModelBase and unified API for GenericModel and ModelDBIR

Summary:
To avoid a hard-coded model_zoo, GenericModel needs to be pluggable. This diff extracts a replaceable base class, ImplicitronModelBase, which both GenericModel and ModelDBIR now implement and register, and aligns their creation APIs for convenience.
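
After this change both models can be created through the same registry-driven path. A minimal sketch, following the updated test in this diff (the image size and background color below are illustrative values):

    from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
    from pytorch3d.implicitron.models.generic_model import GenericModel  # noqa: registers the class
    from pytorch3d.implicitron.models.model_dbir import ModelDBIR  # noqa: registers the class
    from pytorch3d.implicitron.tools.config import expand_args_fields, registry

    for model_class_type in ["ModelDBIR", "GenericModel"]:
        ModelClass = registry.get(ImplicitronModelBase, model_class_type)
        expand_args_fields(ModelClass)  # materialize the dataclass-style constructor
        model = ModelClass(
            render_image_width=64,  # illustrative values
            render_image_height=64,
            bg_color=(0.0, 0.0, 0.0),
        )
        model.eval()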

Reviewed By: bottler, davnov134

Differential Revision: D35933093

fbshipit-source-id: 8228926528eb41a795fbfbe32304b8019197e2b1
Roman Shapovalov authored 2022-05-09 15:23:07 -07:00; committed by Facebook GitHub Bot
parent 5c59841863
commit a6dada399d
11 changed files with 282 additions and 178 deletions

View File

@@ -71,7 +71,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
-from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
+from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 from pytorch3d.implicitron.tools import model_io, vis_utils
 from pytorch3d.implicitron.tools.config import (
     enable_get_default_args,
@@ -615,11 +615,11 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device):
         preds = model(
             **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
         )
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
+        implicitron_render = copy.deepcopy(preds["implicitron_render"])
         per_batch_eval_results.append(
             evaluate.eval_batch(
                 frame_data,
-                nvs_prediction,
+                implicitron_render,
                 bg_color="black",
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,

View File

@@ -29,7 +29,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.dataset.utils import is_train_frame
-from pytorch3d.implicitron.models.base import EvaluationMode
+from pytorch3d.implicitron.models.base_model import EvaluationMode
 from pytorch3d.implicitron.tools.configurable import get_default_args
 from pytorch3d.implicitron.tools.eval_video_trajectory import (
     generate_eval_video_cameras,

View File

@@ -5,10 +5,9 @@
 # LICENSE file in the root directory of this source tree.
-import copy
 import dataclasses
 import os
-from typing import cast, Optional
+from typing import cast, Optional, Tuple
 import lpips
 import torch
@@ -76,7 +75,7 @@ def main() -> None:
 def evaluate_dbir_for_category(
     category: str = "apple",
-    bg_color: float = 0.0,
+    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
     task: str = "singlesequence",
     single_sequence_id: Optional[int] = None,
     num_workers: int = 16,
@@ -141,8 +140,9 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")
     # init the simple DBIR model
-    model = ModelDBIR(
-        image_size=image_size,
+    model = ModelDBIR(  # pyre-ignore[28]: ctor implicitly overridden
+        render_image_width=image_size,
+        render_image_height=image_size,
         bg_color=bg_color,
         max_points=int(1e5),
     )
@@ -157,11 +157,10 @@ def evaluate_dbir_for_category(
     for frame_data in tqdm(test_dataloader):
         frame_data = dataclass_to_cuda_(frame_data)
         preds = model(**dataclasses.asdict(frame_data))
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
         per_batch_eval_results.append(
             eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=bg_color,
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,

View File

@@ -9,12 +9,14 @@ import copy
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
 from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
+from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils
 from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
 from pytorch3d.implicitron.tools.image_utils import mask_background
@@ -31,18 +33,6 @@ from visdom import Visdom
 EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
-@dataclass
-class NewViewSynthesisPrediction:
-    """
-    Holds the tensors that describe a result of synthesizing new views.
-    """
-    depth_render: Optional[torch.Tensor] = None
-    image_render: Optional[torch.Tensor] = None
-    mask_render: Optional[torch.Tensor] = None
-    camera_distance: Optional[torch.Tensor] = None
 @dataclass
 class _Visualizer:
     image_render: torch.Tensor
@@ -145,8 +135,8 @@
 def eval_batch(
     frame_data: FrameData,
-    nvs_prediction: NewViewSynthesisPrediction,
-    bg_color: Union[torch.Tensor, str, float] = "black",
+    implicitron_render: ImplicitronRender,
+    bg_color: Union[torch.Tensor, Sequence, str, float] = "black",
     mask_thr: float = 0.5,
     lpips_model=None,
     visualize: bool = False,
@@ -162,14 +152,14 @@ def eval_batch(
     is True), a new-view synthesis method (NVS) is tasked to generate new views
     of the scene from the viewpoint of the target views (for which
     frame_data.frame_type.endswith('known') is False). The resulting
-    synthesized new views, stored in `nvs_prediction`, are compared to the
+    synthesized new views, stored in `implicitron_render`, are compared to the
     target ground truth in `frame_data` in terms of geometry and appearance
     resulting in a dictionary of metrics returned by the `eval_batch` function.
     Args:
         frame_data: A FrameData object containing the input to the new view
             synthesis method.
-        nvs_prediction: The data describing the synthesized new views.
+        implicitron_render: The data describing the synthesized new views.
         bg_color: The background color of the generated new views and the
             ground truth.
         lpips_model: A pre-trained model for evaluating the LPIPS metric.
@@ -184,26 +174,39 @@ def eval_batch(
         ValueError if frame_data does not have frame_type, camera, or image_rgb
         ValueError if the batch has a mix of training and test samples
         ValueError if the batch frames are not [unseen, known, known, ...]
-        ValueError if one of the required fields in nvs_prediction is missing
+        ValueError if one of the required fields in implicitron_render is missing
     """
-    REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
     frame_type = frame_data.frame_type
     if frame_type is None:
         raise ValueError("Frame type has not been set.")
     # we check that all those fields are not None but Pyre can't infer that properly
-    # TODO: assign to local variables
+    # TODO: assign to local variables and simplify the code.
     if frame_data.image_rgb is None:
         raise ValueError("Image is not in the evaluation batch.")
     if frame_data.camera is None:
         raise ValueError("Camera is not in the evaluation batch.")
-    if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
-        raise ValueError("One of the required predicted fields is missing")
+    # eval all results in the resolution of the frame_data image
+    image_resol = tuple(frame_data.image_rgb.shape[2:])
+    # Post-process the render:
+    # 1) check implicitron_render for Nones,
+    # 2) obtain copies to make sure we dont edit the original data,
+    # 3) take only the 1st (target) image
+    # 4) resize to match ground-truth resolution
+    cloned_render: Dict[str, torch.Tensor] = {}
+    for k in ["mask_render", "image_render", "depth_render"]:
+        field = getattr(implicitron_render, k)
+        if field is None:
+            raise ValueError(f"A required predicted field {k} is missing")
+        imode = "bilinear" if k == "image_render" else "nearest"
+        cloned_render[k] = (
+            F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone()
+        )
-    # obtain copies to make sure we dont edit the original data
-    nvs_prediction = copy.deepcopy(nvs_prediction)
     frame_data = copy.deepcopy(frame_data)
     # mask the ground truth depth in case frame_data contains the depth mask
@@ -226,9 +229,6 @@ def eval_batch(
             + " a target view while the rest should be source views."
         )  # TODO: do we need to enforce this?
-    # take only the first (target image)
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
     for k in [
         "depth_map",
         "image_rgb",
@@ -242,10 +242,6 @@ def eval_batch(
     if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
         warnings.warn("Empty or missing depth map in evaluation!")
-    # eval all results in the resolution of the frame_data image
-    # pyre-fixme[16]: `Optional` has no attribute `shape`.
-    image_resol = list(frame_data.image_rgb.shape[2:])
     # threshold the masks to make ground truth binary masks
     mask_fg, mask_crop = [
         (getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
@@ -258,29 +254,14 @@ def eval_batch(
         bg_color=bg_color,
     )
-    # resize to the target resolution
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        imode = "bilinear" if k == "image_render" else "nearest"
-        val = getattr(nvs_prediction, k)
-        setattr(
-            nvs_prediction,
-            k,
-            # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
-            # `List[typing.Any]`.
-            torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
-        )
     # clamp predicted images
-    # pyre-fixme[16]: `Optional` has no attribute `clamp`.
-    image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
+    image_render = cloned_render["image_render"].clamp(0.0, 1.0)
     if visualize:
         visualizer = _Visualizer(
             image_render=image_render,
             image_rgb_masked=image_rgb_masked,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-            # `Optional[torch.Tensor]`.
-            depth_render=nvs_prediction.depth_render,
+            depth_render=cloned_render["depth_render"],
             # pyre-fixme[6]: Expected `Tensor` for 4th param but got
             # `Optional[torch.Tensor]`.
             depth_map=frame_data.depth_map,
@@ -292,9 +273,7 @@ def eval_batch(
     results: Dict[str, Any] = {}
     results["iou"] = iou(
-        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-        # `Optional[torch.Tensor]`.
-        nvs_prediction.mask_render,
+        cloned_render["mask_render"],
         mask_fg,
         mask=mask_crop,
     )
@@ -321,11 +300,7 @@ def eval_batch(
         if name_postfix == "_fg":
             # only record depth metrics for the foreground
             _, abs_ = eval_depth(
-                # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                # `Optional[torch.Tensor]`.
-                nvs_prediction.depth_render,
-                # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                # `Optional[torch.Tensor]`.
+                cloned_render["depth_render"],
                 frame_data.depth_map,
                 get_best_scale=True,
                 mask=loss_mask_now,
@@ -343,7 +318,7 @@ def eval_batch(
     if lpips_model is not None:
         im1, im2 = [
             2.0 * im.clamp(0.0, 1.0) - 1.0
-            for im in (image_rgb_masked, nvs_prediction.image_render)
+            for im in (image_rgb_masked, cloned_render["image_render"])
         ]
         results["lpips"] = lpips_model.forward(im1, im2).item()

View File

@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from pytorch3d.implicitron.tools.config import ReplaceableBase
+from pytorch3d.renderer.cameras import CamerasBase
+
+from .renderer.base import EvaluationMode
+
+
+@dataclass
+class ImplicitronRender:
+    """
+    Holds the tensors that describe a result of rendering.
+    """
+
+    depth_render: Optional[torch.Tensor] = None
+    image_render: Optional[torch.Tensor] = None
+    mask_render: Optional[torch.Tensor] = None
+    camera_distance: Optional[torch.Tensor] = None
+
+    def clone(self) -> "ImplicitronRender":
+        def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+            return t.detach().clone() if t is not None else None
+
+        return ImplicitronRender(
+            depth_render=safe_clone(self.depth_render),
+            image_render=safe_clone(self.image_render),
+            mask_render=safe_clone(self.mask_render),
+            camera_distance=safe_clone(self.camera_distance),
+        )
+
+
+class ImplicitronModelBase(ReplaceableBase):
+    """Replaceable abstract base for all image generation / rendering models.
+
+    `forward()` method produces a render with a depth map.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        *,  # force keyword-only arguments
+        image_rgb: Optional[torch.Tensor],
+        camera: CamerasBase,
+        fg_probability: Optional[torch.Tensor],
+        mask_crop: Optional[torch.Tensor],
+        depth_map: Optional[torch.Tensor],
+        sequence_name: Optional[List[str]],
+        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
+                the first `min(B, n_train_target_views)` images are considered targets and
+                are used to supervise the renders; the rest corresponding to the source
+                viewpoints from which features will be extracted.
+            camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
+                to the viewpoints of target images, from which the rays will be sampled,
+                and source images, which will be used for intersecting with target rays.
+            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
+                foreground masks.
+            mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
+                regions in the input images (i.e. regions that do not correspond
+                to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
+                "mask_sample", rays will be sampled in the non zero regions.
+            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
+            sequence_name: A list of `B` strings corresponding to the sequence names
+                from which images `image_rgb` were extracted. They are used to match
+                target frames with relevant source frames.
+            evaluation_mode: one of EvaluationMode.TRAINING or
+                EvaluationMode.EVALUATION which determines the settings used for
+                rendering.
+
+        Returns:
+            preds: A dictionary containing all outputs of the forward pass. All models should
+                output an instance of `ImplicitronRender` in `preds["implicitron_render"]`.
+        """
+        raise NotImplementedError()
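
The point of the new base class is that downstream code can plug its own model into the trainer via the registry. A hypothetical minimal implementation (not part of this diff; the class name and behaviour are made up for illustration), following the same pattern ModelDBIR adopts below:

    import torch
    from pytorch3d.implicitron.models.base_model import (
        ImplicitronModelBase,
        ImplicitronRender,
    )
    from pytorch3d.implicitron.tools.config import registry


    @registry.register
    class ConstantColorModel(ImplicitronModelBase, torch.nn.Module):
        # configurable field, picked up by expand_args_fields
        gray_level: float = 0.5

        def __post_init__(self):
            super().__init__()

        def forward(self, *, image_rgb, camera, fg_probability, mask_crop,
                    depth_map, sequence_name, evaluation_mode, **kwargs):
            # render a constant image at the input resolution
            b, _, h, w = image_rgb.shape
            render = ImplicitronRender(
                image_render=image_rgb.new_full((b, 3, h, w), self.gray_level),
                depth_render=image_rgb.new_zeros(b, 1, h, w),
                mask_render=image_rgb.new_ones(b, 1, h, w),
            )
            return {"implicitron_render": render}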

View File

@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import tqdm
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
-    NewViewSynthesisPrediction,
-)
 from pytorch3d.implicitron.tools import image_utils, vis_utils
-from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
+from pytorch3d.implicitron.tools.config import (
+    registry,
+    run_auto_creation,
+)
 from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
 from pytorch3d.implicitron.tools.utils import cat_dataclass
 from pytorch3d.renderer import RayBundle, utils as rend_utils
@@ -25,6 +25,7 @@ from pytorch3d.renderer.cameras import CamerasBase
 from visdom import Visdom
 from .autodecoder import Autodecoder
+from .base_model import ImplicitronModelBase, ImplicitronRender
 from .implicit_function.base import ImplicitFunctionBase
 from .implicit_function.idr_feature_field import IdrFeatureField  # noqa
 from .implicit_function.neural_radiance_field import (  # noqa
@@ -56,8 +57,8 @@ STD_LOG_VARS = ["objective", "epoch", "sec/it"]
 logger = logging.getLogger(__name__)
-# pyre-ignore: 13
-class GenericModel(Configurable, torch.nn.Module):
+@registry.register
+class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     """
     GenericModel is a wrapper for the neural implicit
     rendering and reconstruction pipeline which consists
@@ -452,7 +453,7 @@ class GenericModel(Configurable, torch.nn.Module):
         preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
         preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
-        preds["nvs_prediction"] = NewViewSynthesisPrediction(
+        preds["implicitron_render"] = ImplicitronRender(
             image_render=preds["images_render"],
             depth_render=preds["depths_render"],
             mask_render=preds["masks_render"],

View File

@@ -5,13 +5,11 @@
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple
 import torch
 from pytorch3d.implicitron.dataset.utils import is_known_frame
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
-    NewViewSynthesisPrediction,
-)
+from pytorch3d.implicitron.tools.config import registry
 from pytorch3d.implicitron.tools.point_cloud_utils import (
     get_rgbd_point_cloud,
     render_point_cloud_pytorch3d,
@@ -19,41 +17,43 @@ from pytorch3d.implicitron.tools.point_cloud_utils import (
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.structures import Pointclouds
+from .base_model import ImplicitronModelBase, ImplicitronRender
+from .renderer.base import EvaluationMode
-class ModelDBIR(torch.nn.Module):
+@registry.register
+class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
     """
     A simple depth-based image rendering model.
-    """
-    def __init__(
-        self,
-        image_size: int = 256,
-        bg_color: float = 0.0,
-        max_points: int = -1,
-    ):
-        """
-        Initializes a simple DBIR model.
     Args:
-        image_size: The size of the rendered rectangular images.
+        render_image_width: The width of the rendered rectangular images.
+        render_image_height: The height of the rendered rectangular images.
         bg_color: The color of the background.
         max_points: Maximum number of points in the point cloud
             formed by unprojecting all source view depths.
            If more points are present, they are randomly subsampled
-            to #max_size points without replacement.
+            to this number of points without replacement.
     """
+    render_image_width: int = 256
+    render_image_height: int = 256
+    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
+    max_points: int = -1
+    def __post_init__(self):
         super().__init__()
-        self.image_size = image_size
-        self.bg_color = bg_color
-        self.max_points = max_points
     def forward(
         self,
+        *,  # force keyword-only arguments
+        image_rgb: Optional[torch.Tensor],
         camera: CamerasBase,
-        image_rgb: torch.Tensor,
-        depth_map: torch.Tensor,
-        fg_probability: torch.Tensor,
+        fg_probability: Optional[torch.Tensor],
+        mask_crop: Optional[torch.Tensor],
+        depth_map: Optional[torch.Tensor],
+        sequence_name: Optional[List[str]],
+        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         frame_type: List[str],
         **kwargs,
     ) -> Dict[str, Any]:  # TODO: return a namedtuple or dataclass
@@ -72,12 +72,21 @@
         Returns:
             preds: A dict with the following fields:
-                nvs_prediction: The rendered colors, depth and mask
+                implicitron_render: The rendered colors, depth and mask
                     of the target views.
                 point_cloud: The point cloud of the scene. It's renders are
-                    stored in `nvs_prediction`.
+                    stored in `implicitron_render`.
         """
+        if image_rgb is None:
+            raise ValueError("ModelDBIR needs image input")
+        if fg_probability is None:
+            raise ValueError("ModelDBIR needs foreground mask input")
+        if depth_map is None:
+            raise ValueError("ModelDBIR needs depth map input")
         is_known = is_known_frame(frame_type)
         is_known_idx = torch.where(is_known)[0]
@@ -108,7 +117,7 @@
             _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
                 camera[int(tgt_idx)],
                 point_cloud,
-                render_size=(self.image_size, self.image_size),
+                render_size=(self.render_image_height, self.render_image_width),
                 point_radius=1e-2,
                 topk=10,
                 bg_color=self.bg_color,
@@ -121,7 +130,7 @@
             image_render.append(_image_render)
             mask_render.append(_mask_render)
-        nvs_prediction = NewViewSynthesisPrediction(
+        implicitron_render = ImplicitronRender(
             **{
                 k: torch.cat(v, dim=0)
                 for k, v in zip(
@@ -132,7 +141,7 @@
         )
         preds = {
-            "nvs_prediction": nvs_prediction,
+            "implicitron_render": implicitron_render,
             "point_cloud": point_cloud,
         }
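
ModelDBIR is thus configured through dataclass-style fields rather than an explicit constructor, matching GenericModel. Construction now looks like the updated eval_demo above (the pyre-ignore mirrors the one added there; image_size and bg_color come from the dataset/config):

    model = ModelDBIR(  # pyre-ignore[28]: ctor implicitly overridden
        render_image_width=image_size,
        render_image_height=image_size,
        bg_color=bg_color,
        max_points=int(1e5),
    )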

View File

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
-from typing import Union
+from typing import Sequence, Union
 import torch
@@ -14,7 +14,7 @@ def mask_background(
     image_rgb: torch.Tensor,
     mask_fg: torch.Tensor,
     dim_color: int = 1,
-    bg_color: Union[torch.Tensor, str, float] = 0.0,
+    bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0,
 ) -> torch.Tensor:
     """
     Mask the background input image tensor `image_rgb` with `bg_color`.
@@ -26,9 +26,11 @@ def mask_background(
     # obtain the background color tensor
     if isinstance(bg_color, torch.Tensor):
         bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
-    elif isinstance(bg_color, float):
+    elif isinstance(bg_color, (float, tuple, list)):
+        if isinstance(bg_color, float):
+            bg_color = [bg_color] * 3
         bg_color_t = torch.tensor(
-            [bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
+            bg_color, device=image_rgb.device, dtype=image_rgb.dtype
         ).view(*tgt_view)
     elif isinstance(bg_color, str):
         if bg_color == "white":
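
mask_background therefore now accepts a per-channel background colour as well as a scalar, string, or tensor. A small usage sketch (tensor shapes assumed to follow the function's (B, 3, H, W) image / (B, 1, H, W) mask convention):

    image_masked = mask_background(
        image_rgb,                 # (B, 3, H, W)
        mask_fg,                   # (B, 1, H, W)
        bg_color=(0.0, 0.0, 0.0),  # previously limited to a single float
    )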

View File

@@ -9,7 +9,7 @@ import unittest
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.models.autodecoder import Autodecoder
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
     IdrFeatureField,
 )

View File

@@ -6,7 +6,6 @@
 import contextlib
-import copy
 import dataclasses
 import itertools
 import math
@@ -19,8 +18,13 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
 )
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
-from pytorch3d.implicitron.models.model_dbir import ModelDBIR
+from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
+    eval_batch,
+)
+from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
+from pytorch3d.implicitron.models.generic_model import GenericModel  # noqa
+from pytorch3d.implicitron.models.model_dbir import ModelDBIR  # noqa
+from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
@@ -43,7 +47,7 @@ class TestEvaluation(unittest.TestCase):
         category = "skateboard"
         frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
-        self.image_size = 256
+        self.image_size = 64
         self.dataset = ImplicitronDataset(
             frame_annotations_file=frame_file,
             sequence_annotations_file=sequence_file,
@@ -53,11 +57,11 @@ class TestEvaluation(unittest.TestCase):
             box_crop=True,
             path_manager=path_manager,
         )
-        self.bg_color = 0.0
+        self.bg_color = (0.0, 0.0, 0.0)
         # init the lpips model for eval
         provide_lpips_vgg()
-        self.lpips_model = lpips.LPIPS(net="vgg")
+        self.lpips_model = lpips.LPIPS(net="vgg").cuda()
     def test_eval_depth(self):
         """
@@ -200,30 +204,17 @@ class TestEvaluation(unittest.TestCase):
     def _one_sequence_test(
         self,
         seq_dataset,
-        n_batches=2,
-        min_batch_size=5,
-        max_batch_size=10,
+        model,
+        batch_indices,
+        check_metrics=False,
     ):
-        # form a list of random batches
-        batch_indices = []
-        for _ in range(n_batches):
-            batch_size = torch.randint(
-                low=min_batch_size, high=max_batch_size, size=(1,)
-            )
-            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
         loader = torch.utils.data.DataLoader(
             seq_dataset,
-            # batch_size=1,
             shuffle=False,
             batch_sampler=batch_indices,
             collate_fn=FrameData.collate,
         )
-        model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
-        model.cuda()
-        self.lpips_model.cuda()
         for frame_data in loader:
             self.assertIsNone(frame_data.frame_type)
             self.assertIsNotNone(frame_data.image_rgb)
@@ -233,33 +224,37 @@ class TestEvaluation(unittest.TestCase):
                 *(["train_known"] * (len(frame_data.image_rgb) - 1)),
             ]
-            # move frame_data to gpu
             frame_data = dataclass_to_cuda_(frame_data)
             preds = model(**dataclasses.asdict(frame_data))
-            nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
             eval_result = eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=self.bg_color,
                 lpips_model=self.lpips_model,
             )
+            if check_metrics:
+                self._check_metrics(
+                    frame_data, preds["implicitron_render"], eval_result
+                )
+    def _check_metrics(self, frame_data, implicitron_render, eval_result):
         # Make a terribly bad NVS prediction and check that this is worse
         # than the DBIR prediction.
-        nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
-        nvs_prediction_bad.depth_render += (
-            torch.randn_like(nvs_prediction.depth_render) * 100.0
+        implicitron_render_bad = implicitron_render.clone()
+        implicitron_render_bad.depth_render += (
+            torch.randn_like(implicitron_render_bad.depth_render) * 100.0
         )
-        nvs_prediction_bad.image_render += (
-            torch.randn_like(nvs_prediction.image_render) * 100.0
+        implicitron_render_bad.image_render += (
+            torch.randn_like(implicitron_render_bad.image_render) * 100.0
         )
-        nvs_prediction_bad.mask_render = (
-            torch.randn_like(nvs_prediction.mask_render) > 0.0
+        implicitron_render_bad.mask_render = (
+            torch.randn_like(implicitron_render_bad.mask_render) > 0.0
         ).float()
         eval_result_bad = eval_batch(
             frame_data,
-            nvs_prediction_bad,
+            implicitron_render_bad,
             bg_color=self.bg_color,
             lpips_model=self.lpips_model,
         )
@@ -273,7 +268,7 @@ class TestEvaluation(unittest.TestCase):
             "rgb_l1_fg": True,
         }
-        for metric in lower_better.keys():
+        for metric in lower_better:
             m_better = eval_result[metric]
             m_worse = eval_result_bad[metric]
             if m_better != m_better or m_worse != m_worse:
@@ -285,9 +280,45 @@ class TestEvaluation(unittest.TestCase):
             )
             _assert(m_better, m_worse)
+    def _get_random_batch_indices(
+        self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10
+    ):
+        batch_indices = []
+        for _ in range(n_batches):
+            batch_size = torch.randint(
+                low=min_batch_size, high=max_batch_size, size=(1,)
+            )
+            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
+        return batch_indices
     def test_full_eval(self, n_sequences=5):
         """Test evaluation."""
+        # caching batch indices first to preserve RNG state
+        seq_datasets = {}
+        batch_indices = {}
         for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
             idx = list(self.dataset.sequence_indices_in_order(seq))
             seq_dataset = torch.utils.data.Subset(self.dataset, idx)
-            self._one_sequence_test(seq_dataset)
+            seq_datasets[seq] = seq_dataset
+            batch_indices[seq] = self._get_random_batch_indices(seq_dataset)
+        for model_class_type in ["ModelDBIR", "GenericModel"]:
+            ModelClass = registry.get(ImplicitronModelBase, model_class_type)
+            expand_args_fields(ModelClass)
+            model = ModelClass(
+                render_image_width=self.image_size,
+                render_image_height=self.image_size,
+                bg_color=self.bg_color,
+            )
+            model.eval()
+            model.cuda()
+            for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
+                self._one_sequence_test(
+                    seq_datasets[seq],
+                    model,
+                    batch_indices[seq],
+                    check_metrics=(model_class_type == "ModelDBIR"),
+                )

View File

@@ -7,7 +7,7 @@
 import unittest
 import torch
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.renderer.base import EvaluationMode
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras