mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Extracted ImplicitronModelBase and unified API for GenericModel and ModelDBIR
Summary: To avoid model_zoo, we need to make GenericModel pluggable. I also align creation APIs for convenience.

Reviewed By: bottler, davnov134

Differential Revision: D35933093

fbshipit-source-id: 8228926528eb41a795fbfbe32304b8019197e2b1
This commit is contained in:
parent 5c59841863
commit a6dada399d
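The unified creation API can be seen in the test changes at the bottom of this diff: any model registered against `ImplicitronModelBase` is looked up by name and constructed with the same keyword arguments. A minimal sketch of that pattern (the argument values are examples, mirroring the test code below):

    from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
    from pytorch3d.implicitron.tools.config import expand_args_fields, registry

    # Look up a registered ImplicitronModelBase subclass by name, without
    # importing the concrete class (the point of avoiding model_zoo).
    ModelClass = registry.get(ImplicitronModelBase, "ModelDBIR")
    expand_args_fields(ModelClass)  # materialize the dataclass-style ctor args
    model = ModelClass(
        render_image_width=256,
        render_image_height=256,
        bg_color=(0.0, 0.0, 0.0),
    )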
@@ -71,7 +71,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
-from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
+from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 from pytorch3d.implicitron.tools import model_io, vis_utils
 from pytorch3d.implicitron.tools.config import (
     enable_get_default_args,
@@ -615,11 +615,11 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device):
         preds = model(
             **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
         )
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
+        implicitron_render = copy.deepcopy(preds["implicitron_render"])
         per_batch_eval_results.append(
             evaluate.eval_batch(
                 frame_data,
-                nvs_prediction,
+                implicitron_render,
                 bg_color="black",
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,
@@ -29,7 +29,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.dataset.utils import is_train_frame
-from pytorch3d.implicitron.models.base import EvaluationMode
+from pytorch3d.implicitron.models.base_model import EvaluationMode
 from pytorch3d.implicitron.tools.configurable import get_default_args
 from pytorch3d.implicitron.tools.eval_video_trajectory import (
     generate_eval_video_cameras,
@@ -5,10 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
-
 import copy
 import dataclasses
 import os
-from typing import cast, Optional
+from typing import cast, Optional, Tuple
 
 import lpips
 import torch
@@ -76,7 +75,7 @@ def main() -> None:
 
 def evaluate_dbir_for_category(
     category: str = "apple",
-    bg_color: float = 0.0,
+    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
     task: str = "singlesequence",
     single_sequence_id: Optional[int] = None,
     num_workers: int = 16,
@@ -141,8 +140,9 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")
 
     # init the simple DBIR model
-    model = ModelDBIR(
-        image_size=image_size,
+    model = ModelDBIR(  # pyre-ignore[28]: c’tor implicitly overridden
+        render_image_width=image_size,
+        render_image_height=image_size,
         bg_color=bg_color,
         max_points=int(1e5),
     )
@@ -157,11 +157,10 @@ def evaluate_dbir_for_category(
     for frame_data in tqdm(test_dataloader):
         frame_data = dataclass_to_cuda_(frame_data)
        preds = model(**dataclasses.asdict(frame_data))
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
         per_batch_eval_results.append(
             eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=bg_color,
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,
@@ -9,12 +9,14 @@ import copy
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
+from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils
 from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
 from pytorch3d.implicitron.tools.image_utils import mask_background
@@ -31,18 +33,6 @@ from visdom import Visdom
 EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
 
 
-@dataclass
-class NewViewSynthesisPrediction:
-    """
-    Holds the tensors that describe a result of synthesizing new views.
-    """
-
-    depth_render: Optional[torch.Tensor] = None
-    image_render: Optional[torch.Tensor] = None
-    mask_render: Optional[torch.Tensor] = None
-    camera_distance: Optional[torch.Tensor] = None
-
-
 @dataclass
 class _Visualizer:
     image_render: torch.Tensor
@@ -145,8 +135,8 @@ class _Visualizer:
 
 def eval_batch(
     frame_data: FrameData,
-    nvs_prediction: NewViewSynthesisPrediction,
-    bg_color: Union[torch.Tensor, str, float] = "black",
+    implicitron_render: ImplicitronRender,
+    bg_color: Union[torch.Tensor, Sequence, str, float] = "black",
     mask_thr: float = 0.5,
     lpips_model=None,
     visualize: bool = False,
@@ -162,14 +152,14 @@
         is True), a new-view synthesis method (NVS) is tasked to generate new views
         of the scene from the viewpoint of the target views (for which
         frame_data.frame_type.endswith('known') is False). The resulting
-        synthesized new views, stored in `nvs_prediction`, are compared to the
+        synthesized new views, stored in `implicitron_render`, are compared to the
         target ground truth in `frame_data` in terms of geometry and appearance
         resulting in a dictionary of metrics returned by the `eval_batch` function.
 
     Args:
         frame_data: A FrameData object containing the input to the new view
             synthesis method.
-        nvs_prediction: The data describing the synthesized new views.
+        implicitron_render: The data describing the synthesized new views.
         bg_color: The background color of the generated new views and the
             ground truth.
         lpips_model: A pre-trained model for evaluating the LPIPS metric.
@@ -184,26 +174,39 @@
         ValueError if frame_data does not have frame_type, camera, or image_rgb
         ValueError if the batch has a mix of training and test samples
         ValueError if the batch frames are not [unseen, known, known, ...]
-        ValueError if one of the required fields in nvs_prediction is missing
+        ValueError if one of the required fields in implicitron_render is missing
     """
-    REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
     frame_type = frame_data.frame_type
     if frame_type is None:
         raise ValueError("Frame type has not been set.")
 
     # we check that all those fields are not None but Pyre can't infer that properly
-    # TODO: assign to local variables
+    # TODO: assign to local variables and simplify the code.
     if frame_data.image_rgb is None:
         raise ValueError("Image is not in the evaluation batch.")
 
     if frame_data.camera is None:
         raise ValueError("Camera is not in the evaluation batch.")
 
-    if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
-        raise ValueError("One of the required predicted fields is missing")
+    # eval all results in the resolution of the frame_data image
+    image_resol = tuple(frame_data.image_rgb.shape[2:])
+
+    # Post-process the render:
+    # 1) check implicitron_render for Nones,
+    # 2) obtain copies to make sure we dont edit the original data,
+    # 3) take only the 1st (target) image
+    # 4) resize to match ground-truth resolution
+    cloned_render: Dict[str, torch.Tensor] = {}
+    for k in ["mask_render", "image_render", "depth_render"]:
+        field = getattr(implicitron_render, k)
+        if field is None:
+            raise ValueError(f"A required predicted field {k} is missing")
+
+        imode = "bilinear" if k == "image_render" else "nearest"
+        cloned_render[k] = (
+            F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone()
+        )
 
-    # obtain copies to make sure we dont edit the original data
-    nvs_prediction = copy.deepcopy(nvs_prediction)
     frame_data = copy.deepcopy(frame_data)
 
     # mask the ground truth depth in case frame_data contains the depth mask
@@ -226,9 +229,6 @@
             + " a target view while the rest should be source views."
         )  # TODO: do we need to enforce this?
 
-    # take only the first (target image)
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
     for k in [
         "depth_map",
         "image_rgb",
@@ -242,10 +242,6 @@
     if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
         warnings.warn("Empty or missing depth map in evaluation!")
 
-    # eval all results in the resolution of the frame_data image
-    # pyre-fixme[16]: `Optional` has no attribute `shape`.
-    image_resol = list(frame_data.image_rgb.shape[2:])
-
     # threshold the masks to make ground truth binary masks
     mask_fg, mask_crop = [
         (getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
@@ -258,29 +254,14 @@
         bg_color=bg_color,
     )
 
-    # resize to the target resolution
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        imode = "bilinear" if k == "image_render" else "nearest"
-        val = getattr(nvs_prediction, k)
-        setattr(
-            nvs_prediction,
-            k,
-            # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
-            #  `List[typing.Any]`.
-            torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
-        )
-
     # clamp predicted images
-    # pyre-fixme[16]: `Optional` has no attribute `clamp`.
-    image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
+    image_render = cloned_render["image_render"].clamp(0.0, 1.0)
 
     if visualize:
         visualizer = _Visualizer(
             image_render=image_render,
             image_rgb_masked=image_rgb_masked,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-            #  `Optional[torch.Tensor]`.
-            depth_render=nvs_prediction.depth_render,
-            # pyre-fixme[6]: Expected `Tensor` for 4th param but got
-            #  `Optional[torch.Tensor]`.
+            depth_render=cloned_render["depth_render"],
             depth_map=frame_data.depth_map,
@@ -292,9 +273,7 @@
     results: Dict[str, Any] = {}
 
     results["iou"] = iou(
-        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-        #  `Optional[torch.Tensor]`.
-        nvs_prediction.mask_render,
+        cloned_render["mask_render"],
         mask_fg,
         mask=mask_crop,
     )
@@ -321,11 +300,7 @@
         if name_postfix == "_fg":
             # only record depth metrics for the foreground
             _, abs_ = eval_depth(
-                # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                #  `Optional[torch.Tensor]`.
-                nvs_prediction.depth_render,
-                # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                #  `Optional[torch.Tensor]`.
+                cloned_render["depth_render"],
                 frame_data.depth_map,
                 get_best_scale=True,
                 mask=loss_mask_now,
@@ -343,7 +318,7 @@
     if lpips_model is not None:
         im1, im2 = [
             2.0 * im.clamp(0.0, 1.0) - 1.0
-            for im in (image_rgb_masked, nvs_prediction.image_render)
+            for im in (image_rgb_masked, cloned_render["image_render"])
         ]
         results["lpips"] = lpips_model.forward(im1, im2).item()
 
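Taken together, the hunks above switch `eval_batch` from the removed `NewViewSynthesisPrediction` to `ImplicitronRender`, which it now checks, resizes, and clones internally via `cloned_render`. A sketch of the resulting calling convention, assuming `model`, `frame_data`, and `lpips_model` come from the surrounding evaluation loop:

    preds = model(**dataclasses.asdict(frame_data))
    eval_result = eval_batch(
        frame_data,
        preds["implicitron_render"],  # ImplicitronRender replaces nvs_prediction
        bg_color=(0.0, 0.0, 0.0),     # Sequence values are accepted after this change
        lpips_model=lpips_model,
    )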
pytorch3d/implicitron/models/base_model.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from pytorch3d.implicitron.tools.config import ReplaceableBase
+from pytorch3d.renderer.cameras import CamerasBase
+
+from .renderer.base import EvaluationMode
+
+
+@dataclass
+class ImplicitronRender:
+    """
+    Holds the tensors that describe a result of rendering.
+    """
+
+    depth_render: Optional[torch.Tensor] = None
+    image_render: Optional[torch.Tensor] = None
+    mask_render: Optional[torch.Tensor] = None
+    camera_distance: Optional[torch.Tensor] = None
+
+    def clone(self) -> "ImplicitronRender":
+        def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+            return t.detach().clone() if t is not None else None
+
+        return ImplicitronRender(
+            depth_render=safe_clone(self.depth_render),
+            image_render=safe_clone(self.image_render),
+            mask_render=safe_clone(self.mask_render),
+            camera_distance=safe_clone(self.camera_distance),
+        )
+
+
+class ImplicitronModelBase(ReplaceableBase):
+    """Replaceable abstract base for all image generation / rendering models.
+    `forward()` method produces a render with a depth map.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        *,  # force keyword-only arguments
+        image_rgb: Optional[torch.Tensor],
+        camera: CamerasBase,
+        fg_probability: Optional[torch.Tensor],
+        mask_crop: Optional[torch.Tensor],
+        depth_map: Optional[torch.Tensor],
+        sequence_name: Optional[List[str]],
+        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
+                the first `min(B, n_train_target_views)` images are considered targets and
+                are used to supervise the renders; the rest corresponding to the source
+                viewpoints from which features will be extracted.
+            camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
+                to the viewpoints of target images, from which the rays will be sampled,
+                and source images, which will be used for intersecting with target rays.
+            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
+                foreground masks.
+            mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
+                regions in the input images (i.e. regions that do not correspond
+                to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
+                "mask_sample", rays will be sampled in the non zero regions.
+            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
+            sequence_name: A list of `B` strings corresponding to the sequence names
+                from which images `image_rgb` were extracted. They are used to match
+                target frames with relevant source frames.
+            evaluation_mode: one of EvaluationMode.TRAINING or
+                EvaluationMode.EVALUATION which determines the settings used for
+                rendering.
+
+        Returns:
+            preds: A dictionary containing all outputs of the forward pass. All models should
+                output an instance of `ImplicitronRender` in `preds["implicitron_render"]`.
+        """
+        raise NotImplementedError()
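Since `ImplicitronModelBase` extends `ReplaceableBase`, third-party models can plug into the registry the same way `GenericModel` and `ModelDBIR` do later in this diff. A hypothetical toy subclass (illustrative only, not part of the commit):

    import torch
    from pytorch3d.implicitron.models.base_model import (
        ImplicitronModelBase,
        ImplicitronRender,
    )
    from pytorch3d.implicitron.tools.config import registry

    @registry.register
    class ConstantColorModel(ImplicitronModelBase, torch.nn.Module):
        """Toy model returning an all-black render; for illustration only."""

        def forward(self, *, image_rgb, camera, **kwargs):
            b, _, h, w = image_rgb.shape
            render = ImplicitronRender(
                image_render=image_rgb.new_zeros(1, 3, h, w),
                depth_render=image_rgb.new_zeros(1, 1, h, w),
                mask_render=image_rgb.new_ones(1, 1, h, w),
            )
            # Every model is expected to populate preds["implicitron_render"].
            return {"implicitron_render": render}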
@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import tqdm
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
-    NewViewSynthesisPrediction,
-)
 from pytorch3d.implicitron.tools import image_utils, vis_utils
-from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
+from pytorch3d.implicitron.tools.config import (
+    registry,
+    run_auto_creation,
+)
 from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
 from pytorch3d.implicitron.tools.utils import cat_dataclass
 from pytorch3d.renderer import RayBundle, utils as rend_utils
@@ -25,6 +25,7 @@ from pytorch3d.renderer.cameras import CamerasBase
 from visdom import Visdom
 
 from .autodecoder import Autodecoder
+from .base_model import ImplicitronModelBase, ImplicitronRender
 from .implicit_function.base import ImplicitFunctionBase
 from .implicit_function.idr_feature_field import IdrFeatureField  # noqa
 from .implicit_function.neural_radiance_field import (  # noqa
@@ -56,8 +57,8 @@ STD_LOG_VARS = ["objective", "epoch", "sec/it"]
 logger = logging.getLogger(__name__)
 
 
-# pyre-ignore: 13
-class GenericModel(Configurable, torch.nn.Module):
+@registry.register
+class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     """
     GenericModel is a wrapper for the neural implicit
     rendering and reconstruction pipeline which consists
@@ -452,7 +453,7 @@ class GenericModel(Configurable, torch.nn.Module):
             preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
             preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
 
-            preds["nvs_prediction"] = NewViewSynthesisPrediction(
+            preds["implicitron_render"] = ImplicitronRender(
                 image_render=preds["images_render"],
                 depth_render=preds["depths_render"],
                 mask_render=preds["masks_render"],
@@ -5,13 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from pytorch3d.implicitron.dataset.utils import is_known_frame
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
-    NewViewSynthesisPrediction,
-)
+from pytorch3d.implicitron.tools.config import registry
 from pytorch3d.implicitron.tools.point_cloud_utils import (
     get_rgbd_point_cloud,
     render_point_cloud_pytorch3d,
@@ -19,41 +17,43 @@ from pytorch3d.implicitron.tools.point_cloud_utils import (
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.structures import Pointclouds
 
+from .base_model import ImplicitronModelBase, ImplicitronRender
+from .renderer.base import EvaluationMode
 
-class ModelDBIR(torch.nn.Module):
+
+@registry.register
+class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
     """
     A simple depth-based image rendering model.
 
+    Args:
+        render_image_width: The width of the rendered rectangular images.
+        render_image_height: The height of the rendered rectangular images.
+        bg_color: The color of the background.
+        max_points: Maximum number of points in the point cloud
+            formed by unprojecting all source view depths.
+            If more points are present, they are randomly subsampled
+            to this number of points without replacement.
     """
 
-    def __init__(
-        self,
-        image_size: int = 256,
-        bg_color: float = 0.0,
-        max_points: int = -1,
-    ):
-        """
-        Initializes a simple DBIR model.
-
-        Args:
-            image_size: The size of the rendered rectangular images.
-            bg_color: The color of the background.
-            max_points: Maximum number of points in the point cloud
-                formed by unprojecting all source view depths.
-                If more points are present, they are randomly subsampled
-                to #max_size points without replacement.
-        """
+    render_image_width: int = 256
+    render_image_height: int = 256
+    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
+    max_points: int = -1
+
+    def __post_init__(self):
         super().__init__()
-        self.image_size = image_size
-        self.bg_color = bg_color
-        self.max_points = max_points
 
     def forward(
         self,
+        *,  # force keyword-only arguments
+        image_rgb: Optional[torch.Tensor],
         camera: CamerasBase,
-        image_rgb: torch.Tensor,
-        depth_map: torch.Tensor,
-        fg_probability: torch.Tensor,
+        fg_probability: Optional[torch.Tensor],
+        mask_crop: Optional[torch.Tensor],
+        depth_map: Optional[torch.Tensor],
+        sequence_name: Optional[List[str]],
+        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         frame_type: List[str],
         **kwargs,
     ) -> Dict[str, Any]:  # TODO: return a namedtuple or dataclass
@@ -72,12 +72,21 @@ class ModelDBIR(torch.nn.Module):
 
         Returns:
             preds: A dict with the following fields:
-                nvs_prediction: The rendered colors, depth and mask
+                implicitron_render: The rendered colors, depth and mask
                     of the target views.
                 point_cloud: The point cloud of the scene. Its renders are
-                    stored in `nvs_prediction`.
+                    stored in `implicitron_render`.
         """
 
+        if image_rgb is None:
+            raise ValueError("ModelDBIR needs image input")
+
+        if fg_probability is None:
+            raise ValueError("ModelDBIR needs foreground mask input")
+
+        if depth_map is None:
+            raise ValueError("ModelDBIR needs depth map input")
+
         is_known = is_known_frame(frame_type)
         is_known_idx = torch.where(is_known)[0]
 
@@ -108,7 +117,7 @@ class ModelDBIR(torch.nn.Module):
             _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
                 camera[int(tgt_idx)],
                 point_cloud,
-                render_size=(self.image_size, self.image_size),
+                render_size=(self.render_image_height, self.render_image_width),
                 point_radius=1e-2,
                 topk=10,
                 bg_color=self.bg_color,
@@ -121,7 +130,7 @@ class ModelDBIR(torch.nn.Module):
             image_render.append(_image_render)
             mask_render.append(_mask_render)
 
-        nvs_prediction = NewViewSynthesisPrediction(
+        implicitron_render = ImplicitronRender(
             **{
                 k: torch.cat(v, dim=0)
                 for k, v in zip(
@@ -132,7 +141,7 @@ class ModelDBIR(torch.nn.Module):
         )
 
         preds = {
-            "nvs_prediction": nvs_prediction,
+            "implicitron_render": implicitron_render,
             "point_cloud": point_cloud,
         }
 
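Because `ModelDBIR` is now configured through class-level fields and `__post_init__`, it is built via `expand_args_fields` rather than a hand-written constructor. A sketch using the fields defined above (values are examples; `max_points` matches the cap used in eval_demo.py earlier in this diff):

    from pytorch3d.implicitron.models.model_dbir import ModelDBIR
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(ModelDBIR)  # turns the dataclass-style fields into ctor args
    model = ModelDBIR(
        render_image_width=256,
        render_image_height=256,
        bg_color=(0.0, 0.0, 0.0),
        max_points=int(1e5),
    )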
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Union
+from typing import Sequence, Union
 
 import torch
 
||||
@ -14,7 +14,7 @@ def mask_background(
|
||||
image_rgb: torch.Tensor,
|
||||
mask_fg: torch.Tensor,
|
||||
dim_color: int = 1,
|
||||
bg_color: Union[torch.Tensor, str, float] = 0.0,
|
||||
bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Mask the background input image tensor `image_rgb` with `bg_color`.
|
||||
@@ -26,9 +26,11 @@ def mask_background(
     # obtain the background color tensor
     if isinstance(bg_color, torch.Tensor):
         bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
-    elif isinstance(bg_color, float):
+    elif isinstance(bg_color, (float, tuple, list)):
+        if isinstance(bg_color, float):
+            bg_color = [bg_color] * 3
         bg_color_t = torch.tensor(
-            [bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
+            bg_color, device=image_rgb.device, dtype=image_rgb.dtype
         ).view(*tgt_view)
     elif isinstance(bg_color, str):
         if bg_color == "white":
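With `bg_color` widened to accept a `Sequence`, a per-channel background now works alongside the existing float and string forms. A small sketch of the accepted inputs (shapes follow the `(B, 3, H, W)` convention used elsewhere in this diff):

    import torch
    from pytorch3d.implicitron.tools.image_utils import mask_background

    image_rgb = torch.rand(2, 3, 8, 8)        # batch of RGB images
    mask_fg = torch.rand(2, 1, 8, 8) > 0.5    # binary foreground mask

    out_gray = mask_background(image_rgb, mask_fg, bg_color=0.5)
    out_pink = mask_background(image_rgb, mask_fg, bg_color=(1.0, 0.5, 0.5))  # new
    out_white = mask_background(image_rgb, mask_fg, bg_color="white")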
@@ -9,7 +9,7 @@ import unittest
 
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.models.autodecoder import Autodecoder
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
     IdrFeatureField,
 )
@@ -6,7 +6,6 @@
 
 
 import contextlib
-import copy
 import dataclasses
 import itertools
 import math
@@ -19,8 +18,13 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
 )
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
-from pytorch3d.implicitron.models.model_dbir import ModelDBIR
+from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
+    eval_batch,
+)
+from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
+from pytorch3d.implicitron.models.generic_model import GenericModel  # noqa
+from pytorch3d.implicitron.models.model_dbir import ModelDBIR  # noqa
 from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
@@ -43,7 +47,7 @@ class TestEvaluation(unittest.TestCase):
         category = "skateboard"
         frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
-        self.image_size = 256
+        self.image_size = 64
         self.dataset = ImplicitronDataset(
             frame_annotations_file=frame_file,
             sequence_annotations_file=sequence_file,
@@ -53,11 +57,11 @@ class TestEvaluation(unittest.TestCase):
             box_crop=True,
             path_manager=path_manager,
         )
-        self.bg_color = 0.0
+        self.bg_color = (0.0, 0.0, 0.0)
 
         # init the lpips model for eval
         provide_lpips_vgg()
-        self.lpips_model = lpips.LPIPS(net="vgg")
+        self.lpips_model = lpips.LPIPS(net="vgg").cuda()
 
     def test_eval_depth(self):
         """
@@ -200,30 +204,17 @@ class TestEvaluation(unittest.TestCase):
     def _one_sequence_test(
         self,
         seq_dataset,
-        n_batches=2,
-        min_batch_size=5,
-        max_batch_size=10,
+        model,
+        batch_indices,
+        check_metrics=False,
     ):
-        # form a list of random batches
-        batch_indices = []
-        for _ in range(n_batches):
-            batch_size = torch.randint(
-                low=min_batch_size, high=max_batch_size, size=(1,)
-            )
-            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
-
         loader = torch.utils.data.DataLoader(
             seq_dataset,
-            # batch_size=1,
             shuffle=False,
             batch_sampler=batch_indices,
             collate_fn=FrameData.collate,
         )
 
-        model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
-        model.cuda()
-        self.lpips_model.cuda()
-
         for frame_data in loader:
             self.assertIsNone(frame_data.frame_type)
             self.assertIsNotNone(frame_data.image_rgb)
@@ -233,61 +224,101 @@ class TestEvaluation(unittest.TestCase):
                 *(["train_known"] * (len(frame_data.image_rgb) - 1)),
             ]
 
             # move frame_data to gpu
             frame_data = dataclass_to_cuda_(frame_data)
             preds = model(**dataclasses.asdict(frame_data))
 
-            nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
             eval_result = eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=self.bg_color,
                 lpips_model=self.lpips_model,
             )
 
-            # Make a terribly bad NVS prediction and check that this is worse
-            # than the DBIR prediction.
-            nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
-            nvs_prediction_bad.depth_render += (
-                torch.randn_like(nvs_prediction.depth_render) * 100.0
-            )
-            nvs_prediction_bad.image_render += (
-                torch.randn_like(nvs_prediction.image_render) * 100.0
-            )
-            nvs_prediction_bad.mask_render = (
-                torch.randn_like(nvs_prediction.mask_render) > 0.0
-            ).float()
-            eval_result_bad = eval_batch(
-                frame_data,
-                nvs_prediction_bad,
-                bg_color=self.bg_color,
-                lpips_model=self.lpips_model,
-            )
-
-            lower_better = {
-                "psnr": False,
-                "psnr_fg": False,
-                "depth_abs_fg": True,
-                "iou": False,
-                "rgb_l1": True,
-                "rgb_l1_fg": True,
-            }
-
-            for metric in lower_better.keys():
-                m_better = eval_result[metric]
-                m_worse = eval_result_bad[metric]
-                if m_better != m_better or m_worse != m_worse:
-                    continue  # metric is missing, i.e. NaN
-                _assert = (
-                    self.assertLessEqual
-                    if lower_better[metric]
-                    else self.assertGreaterEqual
-                )
-                _assert(m_better, m_worse)
+            if check_metrics:
+                self._check_metrics(
+                    frame_data, preds["implicitron_render"], eval_result
+                )
+
+    def _check_metrics(self, frame_data, implicitron_render, eval_result):
+        # Make a terribly bad NVS prediction and check that this is worse
+        # than the DBIR prediction.
+        implicitron_render_bad = implicitron_render.clone()
+        implicitron_render_bad.depth_render += (
+            torch.randn_like(implicitron_render_bad.depth_render) * 100.0
+        )
+        implicitron_render_bad.image_render += (
+            torch.randn_like(implicitron_render_bad.image_render) * 100.0
+        )
+        implicitron_render_bad.mask_render = (
+            torch.randn_like(implicitron_render_bad.mask_render) > 0.0
+        ).float()
+        eval_result_bad = eval_batch(
+            frame_data,
+            implicitron_render_bad,
+            bg_color=self.bg_color,
+            lpips_model=self.lpips_model,
+        )
+
+        lower_better = {
+            "psnr": False,
+            "psnr_fg": False,
+            "depth_abs_fg": True,
+            "iou": False,
+            "rgb_l1": True,
+            "rgb_l1_fg": True,
+        }
+
+        for metric in lower_better:
+            m_better = eval_result[metric]
+            m_worse = eval_result_bad[metric]
+            if m_better != m_better or m_worse != m_worse:
+                continue  # metric is missing, i.e. NaN
+            _assert = (
+                self.assertLessEqual
+                if lower_better[metric]
+                else self.assertGreaterEqual
+            )
+            _assert(m_better, m_worse)
+
+    def _get_random_batch_indices(
+        self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10
+    ):
+        batch_indices = []
+        for _ in range(n_batches):
+            batch_size = torch.randint(
+                low=min_batch_size, high=max_batch_size, size=(1,)
+            )
+            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
+
+        return batch_indices
 
     def test_full_eval(self, n_sequences=5):
         """Test evaluation."""
+
+        # caching batch indices first to preserve RNG state
+        seq_datasets = {}
+        batch_indices = {}
         for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
             idx = list(self.dataset.sequence_indices_in_order(seq))
             seq_dataset = torch.utils.data.Subset(self.dataset, idx)
-            self._one_sequence_test(seq_dataset)
+            seq_datasets[seq] = seq_dataset
+            batch_indices[seq] = self._get_random_batch_indices(seq_dataset)
+
+        for model_class_type in ["ModelDBIR", "GenericModel"]:
+            ModelClass = registry.get(ImplicitronModelBase, model_class_type)
+            expand_args_fields(ModelClass)
+            model = ModelClass(
+                render_image_width=self.image_size,
+                render_image_height=self.image_size,
+                bg_color=self.bg_color,
+            )
+            model.eval()
+            model.cuda()
+
+            for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
+                self._one_sequence_test(
+                    seq_datasets[seq],
+                    model,
+                    batch_indices[seq],
+                    check_metrics=(model_class_type == "ModelDBIR"),
+                )
@@ -7,7 +7,7 @@
 import unittest
 
 import torch
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.renderer.base import EvaluationMode
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras