Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 20:02:49 +08:00)
Extracted ImplicitronModelBase and unified API for GenericModel and ModelDBIR

Summary: To avoid model_zoo, we need to make GenericModel pluggable. I also align creation APIs for convenience.

Reviewed By: bottler, davnov134

Differential Revision: D35933093

fbshipit-source-id: 8228926528eb41a795fbfbe32304b8019197e2b1
parent 5c59841863
commit a6dada399d
@@ -71,7 +71,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
-from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
+from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 from pytorch3d.implicitron.tools import model_io, vis_utils
 from pytorch3d.implicitron.tools.config import (
     enable_get_default_args,
@@ -615,11 +615,11 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device):
         preds = model(
             **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
         )
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
+        implicitron_render = copy.deepcopy(preds["implicitron_render"])
         per_batch_eval_results.append(
             evaluate.eval_batch(
                 frame_data,
-                nvs_prediction,
+                implicitron_render,
                 bg_color="black",
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,
@@ -29,7 +29,7 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.dataset.utils import is_train_frame
-from pytorch3d.implicitron.models.base import EvaluationMode
+from pytorch3d.implicitron.models.base_model import EvaluationMode
 from pytorch3d.implicitron.tools.configurable import get_default_args
 from pytorch3d.implicitron.tools.eval_video_trajectory import (
     generate_eval_video_cameras,
@@ -5,10 +5,9 @@
 # LICENSE file in the root directory of this source tree.


-import copy
 import dataclasses
 import os
-from typing import cast, Optional
+from typing import cast, Optional, Tuple

 import lpips
 import torch
@@ -76,7 +75,7 @@ def main() -> None:

 def evaluate_dbir_for_category(
     category: str = "apple",
-    bg_color: float = 0.0,
+    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
     task: str = "singlesequence",
     single_sequence_id: Optional[int] = None,
     num_workers: int = 16,
@@ -141,8 +140,9 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")

     # init the simple DBIR model
-    model = ModelDBIR(
-        image_size=image_size,
+    model = ModelDBIR(  # pyre-ignore[28]: c'tor implicitly overridden
+        render_image_width=image_size,
+        render_image_height=image_size,
         bg_color=bg_color,
         max_points=int(1e5),
     )
@@ -157,11 +157,10 @@ def evaluate_dbir_for_category(
     for frame_data in tqdm(test_dataloader):
         frame_data = dataclass_to_cuda_(frame_data)
         preds = model(**dataclasses.asdict(frame_data))
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
         per_batch_eval_results.append(
             eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=bg_color,
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,
@@ -9,12 +9,14 @@ import copy
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union

 import numpy as np
 import torch
+import torch.nn.functional as F
 from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
+from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils
 from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
 from pytorch3d.implicitron.tools.image_utils import mask_background
@@ -31,18 +33,6 @@ from visdom import Visdom
 EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]


-@dataclass
-class NewViewSynthesisPrediction:
-    """
-    Holds the tensors that describe a result of synthesizing new views.
-    """
-
-    depth_render: Optional[torch.Tensor] = None
-    image_render: Optional[torch.Tensor] = None
-    mask_render: Optional[torch.Tensor] = None
-    camera_distance: Optional[torch.Tensor] = None
-
-
 @dataclass
 class _Visualizer:
     image_render: torch.Tensor
@@ -145,8 +135,8 @@ class _Visualizer:

 def eval_batch(
     frame_data: FrameData,
-    nvs_prediction: NewViewSynthesisPrediction,
-    bg_color: Union[torch.Tensor, str, float] = "black",
+    implicitron_render: ImplicitronRender,
+    bg_color: Union[torch.Tensor, Sequence, str, float] = "black",
     mask_thr: float = 0.5,
     lpips_model=None,
     visualize: bool = False,
@@ -162,14 +152,14 @@ def eval_batch(
     is True), a new-view synthesis method (NVS) is tasked to generate new views
     of the scene from the viewpoint of the target views (for which
     frame_data.frame_type.endswith('known') is False). The resulting
-    synthesized new views, stored in `nvs_prediction`, are compared to the
+    synthesized new views, stored in `implicitron_render`, are compared to the
     target ground truth in `frame_data` in terms of geometry and appearance
     resulting in a dictionary of metrics returned by the `eval_batch` function.

     Args:
         frame_data: A FrameData object containing the input to the new view
             synthesis method.
-        nvs_prediction: The data describing the synthesized new views.
+        implicitron_render: The data describing the synthesized new views.
         bg_color: The background color of the generated new views and the
             ground truth.
         lpips_model: A pre-trained model for evaluating the LPIPS metric.
@@ -184,26 +174,39 @@ def eval_batch(
         ValueError if frame_data does not have frame_type, camera, or image_rgb
         ValueError if the batch has a mix of training and test samples
         ValueError if the batch frames are not [unseen, known, known, ...]
-        ValueError if one of the required fields in nvs_prediction is missing
+        ValueError if one of the required fields in implicitron_render is missing
     """
-    REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
     frame_type = frame_data.frame_type
     if frame_type is None:
         raise ValueError("Frame type has not been set.")

     # we check that all those fields are not None but Pyre can't infer that properly
-    # TODO: assign to local variables
+    # TODO: assign to local variables and simplify the code.
     if frame_data.image_rgb is None:
         raise ValueError("Image is not in the evaluation batch.")

     if frame_data.camera is None:
         raise ValueError("Camera is not in the evaluation batch.")

-    if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
-        raise ValueError("One of the required predicted fields is missing")
+    # eval all results in the resolution of the frame_data image
+    image_resol = tuple(frame_data.image_rgb.shape[2:])

-    # obtain copies to make sure we dont edit the original data
-    nvs_prediction = copy.deepcopy(nvs_prediction)
+    # Post-process the render:
+    # 1) check implicitron_render for Nones,
+    # 2) obtain copies to make sure we dont edit the original data,
+    # 3) take only the 1st (target) image
+    # 4) resize to match ground-truth resolution
+    cloned_render: Dict[str, torch.Tensor] = {}
+    for k in ["mask_render", "image_render", "depth_render"]:
+        field = getattr(implicitron_render, k)
+        if field is None:
+            raise ValueError(f"A required predicted field {k} is missing")
+
+        imode = "bilinear" if k == "image_render" else "nearest"
+        cloned_render[k] = (
+            F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone()
+        )
+
     frame_data = copy.deepcopy(frame_data)

     # mask the ground truth depth in case frame_data contains the depth mask
@@ -226,9 +229,6 @@ def eval_batch(
             + " a target view while the rest should be source views."
         )  # TODO: do we need to enforce this?

-    # take only the first (target image)
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
     for k in [
         "depth_map",
         "image_rgb",
@@ -242,10 +242,6 @@ def eval_batch(
     if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
         warnings.warn("Empty or missing depth map in evaluation!")

-    # eval all results in the resolution of the frame_data image
-    # pyre-fixme[16]: `Optional` has no attribute `shape`.
-    image_resol = list(frame_data.image_rgb.shape[2:])
-
     # threshold the masks to make ground truth binary masks
     mask_fg, mask_crop = [
         (getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
@@ -258,29 +254,14 @@ def eval_batch(
         bg_color=bg_color,
     )

-    # resize to the target resolution
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        imode = "bilinear" if k == "image_render" else "nearest"
-        val = getattr(nvs_prediction, k)
-        setattr(
-            nvs_prediction,
-            k,
-            # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
-            #  `List[typing.Any]`.
-            torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
-        )
-
     # clamp predicted images
-    # pyre-fixme[16]: `Optional` has no attribute `clamp`.
-    image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
+    image_render = cloned_render["image_render"].clamp(0.0, 1.0)

     if visualize:
         visualizer = _Visualizer(
             image_render=image_render,
             image_rgb_masked=image_rgb_masked,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-            #  `Optional[torch.Tensor]`.
-            depth_render=nvs_prediction.depth_render,
+            depth_render=cloned_render["depth_render"],
             # pyre-fixme[6]: Expected `Tensor` for 4th param but got
             #  `Optional[torch.Tensor]`.
             depth_map=frame_data.depth_map,
@@ -292,9 +273,7 @@ def eval_batch(
     results: Dict[str, Any] = {}

     results["iou"] = iou(
-        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-        #  `Optional[torch.Tensor]`.
-        nvs_prediction.mask_render,
+        cloned_render["mask_render"],
         mask_fg,
         mask=mask_crop,
     )
@@ -321,11 +300,7 @@ def eval_batch(
         if name_postfix == "_fg":
             # only record depth metrics for the foreground
             _, abs_ = eval_depth(
-                # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                #  `Optional[torch.Tensor]`.
-                nvs_prediction.depth_render,
-                # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                #  `Optional[torch.Tensor]`.
+                cloned_render["depth_render"],
                 frame_data.depth_map,
                 get_best_scale=True,
                 mask=loss_mask_now,
@@ -343,7 +318,7 @@ def eval_batch(
     if lpips_model is not None:
         im1, im2 = [
             2.0 * im.clamp(0.0, 1.0) - 1.0
-            for im in (image_rgb_masked, nvs_prediction.image_render)
+            for im in (image_rgb_masked, cloned_render["image_render"])
         ]
         results["lpips"] = lpips_model.forward(im1, im2).item()
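The new post-processing in eval_batch replaces the old setattr loop: it keeps only the first (target) view, resizes it to the ground-truth resolution, and works on detached copies. Isolated as a sketch with illustrative shapes:

# Sketch of the clone-and-resize step introduced above (shapes illustrative).
import torch
import torch.nn.functional as F

field = torch.rand(4, 3, 32, 32)  # e.g. implicitron_render.image_render
image_resol = (64, 64)            # ground-truth (H, W)

# bilinear for images, nearest for masks/depths; only the first (target) view
resized = F.interpolate(field[:1], size=image_resol, mode="bilinear").detach().clone()
assert resized.shape == (1, 3, 64, 64)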
pytorch3d/implicitron/models/base_model.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from pytorch3d.implicitron.tools.config import ReplaceableBase
+from pytorch3d.renderer.cameras import CamerasBase
+
+from .renderer.base import EvaluationMode
+
+
+@dataclass
+class ImplicitronRender:
+    """
+    Holds the tensors that describe a result of rendering.
+    """
+
+    depth_render: Optional[torch.Tensor] = None
+    image_render: Optional[torch.Tensor] = None
+    mask_render: Optional[torch.Tensor] = None
+    camera_distance: Optional[torch.Tensor] = None
+
+    def clone(self) -> "ImplicitronRender":
+        def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+            return t.detach().clone() if t is not None else None
+
+        return ImplicitronRender(
+            depth_render=safe_clone(self.depth_render),
+            image_render=safe_clone(self.image_render),
+            mask_render=safe_clone(self.mask_render),
+            camera_distance=safe_clone(self.camera_distance),
+        )
+
+
+class ImplicitronModelBase(ReplaceableBase):
+    """Replaceable abstract base for all image generation / rendering models.
+
+    `forward()` method produces a render with a depth map.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        *,  # force keyword-only arguments
+        image_rgb: Optional[torch.Tensor],
+        camera: CamerasBase,
+        fg_probability: Optional[torch.Tensor],
+        mask_crop: Optional[torch.Tensor],
+        depth_map: Optional[torch.Tensor],
+        sequence_name: Optional[List[str]],
+        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
+                the first `min(B, n_train_target_views)` images are considered targets and
+                are used to supervise the renders; the rest corresponding to the source
+                viewpoints from which features will be extracted.
+            camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
+                to the viewpoints of target images, from which the rays will be sampled,
+                and source images, which will be used for intersecting with target rays.
+            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
+                foreground masks.
+            mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
+                regions in the input images (i.e. regions that do not correspond
+                to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
+                "mask_sample", rays will be sampled in the non zero regions.
+            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
+            sequence_name: A list of `B` strings corresponding to the sequence names
+                from which images `image_rgb` were extracted. They are used to match
+                target frames with relevant source frames.
+            evaluation_mode: one of EvaluationMode.TRAINING or
+                EvaluationMode.EVALUATION which determines the settings used for
+                rendering.
+
+        Returns:
+            preds: A dictionary containing all outputs of the forward pass. All models should
+                output an instance of `ImplicitronRender` in `preds["implicitron_render"]`.
+        """
+        raise NotImplementedError()
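A hedged usage sketch for the new ImplicitronRender container (tensor shapes and values are placeholders): eval_batch requires image_render, mask_render and depth_render to be non-None, and clone() returns detached per-field copies that can be perturbed without touching the original, as the updated tests below do.

import torch

from pytorch3d.implicitron.models.base_model import ImplicitronRender

render = ImplicitronRender(
    image_render=torch.rand(1, 3, 64, 64),  # RGB render of the target view
    mask_render=torch.ones(1, 1, 64, 64),   # rendered foreground mask
    depth_render=torch.rand(1, 1, 64, 64),  # rendered depth map
)

# clone() detaches and copies every tensor field.
render_bad = render.clone()
render_bad.image_render += torch.randn_like(render_bad.image_render) * 100.0
assert render.image_render.max() <= 1.0  # the original is left untouched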
@@ -13,11 +13,11 @@ from typing import Any, Dict, List, Optional, Tuple

 import torch
 import tqdm
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
-    NewViewSynthesisPrediction,
-)
 from pytorch3d.implicitron.tools import image_utils, vis_utils
-from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
+from pytorch3d.implicitron.tools.config import (
+    registry,
+    run_auto_creation,
+)
 from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
 from pytorch3d.implicitron.tools.utils import cat_dataclass
 from pytorch3d.renderer import RayBundle, utils as rend_utils
@@ -25,6 +25,7 @@ from pytorch3d.renderer.cameras import CamerasBase
 from visdom import Visdom

 from .autodecoder import Autodecoder
+from .base_model import ImplicitronModelBase, ImplicitronRender
 from .implicit_function.base import ImplicitFunctionBase
 from .implicit_function.idr_feature_field import IdrFeatureField  # noqa
 from .implicit_function.neural_radiance_field import (  # noqa
@@ -56,8 +57,8 @@ STD_LOG_VARS = ["objective", "epoch", "sec/it"]
 logger = logging.getLogger(__name__)


-# pyre-ignore: 13
-class GenericModel(Configurable, torch.nn.Module):
+@registry.register
+class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     """
     GenericModel is a wrapper for the neural implicit
     rendering and reconstruction pipeline which consists
@@ -452,7 +453,7 @@ class GenericModel(Configurable, torch.nn.Module):
         preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
         preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)

-        preds["nvs_prediction"] = NewViewSynthesisPrediction(
+        preds["implicitron_render"] = ImplicitronRender(
             image_render=preds["images_render"],
             depth_render=preds["depths_render"],
             mask_render=preds["masks_render"],
@@ -5,13 +5,11 @@
 # LICENSE file in the root directory of this source tree.


-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 from pytorch3d.implicitron.dataset.utils import is_known_frame
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
-    NewViewSynthesisPrediction,
-)
+from pytorch3d.implicitron.tools.config import registry
 from pytorch3d.implicitron.tools.point_cloud_utils import (
     get_rgbd_point_cloud,
     render_point_cloud_pytorch3d,
@@ -19,41 +17,43 @@ from pytorch3d.implicitron.tools.point_cloud_utils import (
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.structures import Pointclouds

+from .base_model import ImplicitronModelBase, ImplicitronRender
+from .renderer.base import EvaluationMode
+

-class ModelDBIR(torch.nn.Module):
+@registry.register
+class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
     """
     A simple depth-based image rendering model.
-    """
-
-    def __init__(
-        self,
-        image_size: int = 256,
-        bg_color: float = 0.0,
-        max_points: int = -1,
-    ):
-        """
-        Initializes a simple DBIR model.

     Args:
-        image_size: The size of the rendered rectangular images.
+        render_image_width: The width of the rendered rectangular images.
+        render_image_height: The height of the rendered rectangular images.
         bg_color: The color of the background.
         max_points: Maximum number of points in the point cloud
             formed by unprojecting all source view depths.
             If more points are present, they are randomly subsampled
-            to #max_size points without replacement.
+            to this number of points without replacement.
     """
+
+    render_image_width: int = 256
+    render_image_height: int = 256
+    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
+    max_points: int = -1
+
+    def __post_init__(self):
         super().__init__()
-        self.image_size = image_size
-        self.bg_color = bg_color
-        self.max_points = max_points

     def forward(
         self,
+        *,  # force keyword-only arguments
+        image_rgb: Optional[torch.Tensor],
         camera: CamerasBase,
-        image_rgb: torch.Tensor,
-        depth_map: torch.Tensor,
-        fg_probability: torch.Tensor,
+        fg_probability: Optional[torch.Tensor],
+        mask_crop: Optional[torch.Tensor],
+        depth_map: Optional[torch.Tensor],
+        sequence_name: Optional[List[str]],
+        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
         frame_type: List[str],
         **kwargs,
     ) -> Dict[str, Any]:  # TODO: return a namedtuple or dataclass
@@ -72,12 +72,21 @@ class ModelDBIR(torch.nn.Module):

         Returns:
             preds: A dict with the following fields:
-                nvs_prediction: The rendered colors, depth and mask
+                implicitron_render: The rendered colors, depth and mask
                     of the target views.
                 point_cloud: The point cloud of the scene. It's renders are
-                    stored in `nvs_prediction`.
+                    stored in `implicitron_render`.
         """
+
+        if image_rgb is None:
+            raise ValueError("ModelDBIR needs image input")
+
+        if fg_probability is None:
+            raise ValueError("ModelDBIR needs foreground mask input")
+
+        if depth_map is None:
+            raise ValueError("ModelDBIR needs depth map input")
+
         is_known = is_known_frame(frame_type)
         is_known_idx = torch.where(is_known)[0]

@@ -108,7 +117,7 @@ class ModelDBIR(torch.nn.Module):
             _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
                 camera[int(tgt_idx)],
                 point_cloud,
-                render_size=(self.image_size, self.image_size),
+                render_size=(self.render_image_height, self.render_image_width),
                 point_radius=1e-2,
                 topk=10,
                 bg_color=self.bg_color,
@@ -121,7 +130,7 @@ class ModelDBIR(torch.nn.Module):
             image_render.append(_image_render)
             mask_render.append(_mask_render)

-        nvs_prediction = NewViewSynthesisPrediction(
+        implicitron_render = ImplicitronRender(
             **{
                 k: torch.cat(v, dim=0)
                 for k, v in zip(
@@ -132,7 +141,7 @@ class ModelDBIR(torch.nn.Module):
         )

         preds = {
-            "nvs_prediction": nvs_prediction,
+            "implicitron_render": implicitron_render,
             "point_cloud": point_cloud,
         }

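ModelDBIR's constructor arguments became declared configurable fields, so it is now created the same way as GenericModel. A sketch of the new creation path (argument values are illustrative; it mirrors the updated eval_demo.py and test_evaluation.py):

from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(ModelDBIR)  # materialize the declared fields as __init__ args
model = ModelDBIR(
    render_image_width=256,
    render_image_height=256,
    bg_color=(0.0, 0.0, 0.0),
    max_points=int(1e5),
)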
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.


-from typing import Union
+from typing import Sequence, Union

 import torch

@@ -14,7 +14,7 @@ def mask_background(
     image_rgb: torch.Tensor,
     mask_fg: torch.Tensor,
     dim_color: int = 1,
-    bg_color: Union[torch.Tensor, str, float] = 0.0,
+    bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0,
 ) -> torch.Tensor:
     """
     Mask the background input image tensor `image_rgb` with `bg_color`.
@@ -26,9 +26,11 @@ def mask_background(
     # obtain the background color tensor
     if isinstance(bg_color, torch.Tensor):
         bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
-    elif isinstance(bg_color, float):
+    elif isinstance(bg_color, (float, tuple, list)):
+        if isinstance(bg_color, float):
+            bg_color = [bg_color] * 3
         bg_color_t = torch.tensor(
-            [bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
+            bg_color, device=image_rgb.device, dtype=image_rgb.dtype
         ).view(*tgt_view)
     elif isinstance(bg_color, str):
         if bg_color == "white":
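mask_background now accepts a per-channel background color (a 3-sequence) in addition to a scalar or a string name; a small sketch, with illustrative tensor shapes:

import torch

from pytorch3d.implicitron.tools.image_utils import mask_background

image_rgb = torch.rand(2, 3, 8, 8)
mask_fg = (torch.rand(2, 1, 8, 8) > 0.5).float()

masked_black = mask_background(image_rgb, mask_fg, bg_color=0.0)  # scalar, as before
masked_red = mask_background(image_rgb, mask_fg, bg_color=(1.0, 0.0, 0.0))  # new: 3-tuple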
@@ -9,7 +9,7 @@ import unittest

 from omegaconf import OmegaConf
 from pytorch3d.implicitron.models.autodecoder import Autodecoder
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
     IdrFeatureField,
 )
@@ -6,7 +6,6 @@


 import contextlib
-import copy
 import dataclasses
 import itertools
 import math
@@ -19,8 +18,13 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
 )
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
-from pytorch3d.implicitron.models.model_dbir import ModelDBIR
+from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
+    eval_batch,
+)
+from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
+from pytorch3d.implicitron.models.generic_model import GenericModel  # noqa
+from pytorch3d.implicitron.models.model_dbir import ModelDBIR  # noqa
+from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_

@@ -43,7 +47,7 @@ class TestEvaluation(unittest.TestCase):
         category = "skateboard"
         frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
-        self.image_size = 256
+        self.image_size = 64
         self.dataset = ImplicitronDataset(
             frame_annotations_file=frame_file,
             sequence_annotations_file=sequence_file,
@@ -53,11 +57,11 @@ class TestEvaluation(unittest.TestCase):
             box_crop=True,
             path_manager=path_manager,
         )
-        self.bg_color = 0.0
+        self.bg_color = (0.0, 0.0, 0.0)

         # init the lpips model for eval
         provide_lpips_vgg()
-        self.lpips_model = lpips.LPIPS(net="vgg")
+        self.lpips_model = lpips.LPIPS(net="vgg").cuda()

     def test_eval_depth(self):
         """
@@ -200,30 +204,17 @@ class TestEvaluation(unittest.TestCase):
     def _one_sequence_test(
         self,
         seq_dataset,
-        n_batches=2,
-        min_batch_size=5,
-        max_batch_size=10,
+        model,
+        batch_indices,
+        check_metrics=False,
     ):
-        # form a list of random batches
-        batch_indices = []
-        for _ in range(n_batches):
-            batch_size = torch.randint(
-                low=min_batch_size, high=max_batch_size, size=(1,)
-            )
-            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
-
         loader = torch.utils.data.DataLoader(
             seq_dataset,
-            # batch_size=1,
             shuffle=False,
             batch_sampler=batch_indices,
             collate_fn=FrameData.collate,
         )

-        model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
-        model.cuda()
-        self.lpips_model.cuda()
-
         for frame_data in loader:
             self.assertIsNone(frame_data.frame_type)
             self.assertIsNotNone(frame_data.image_rgb)
@@ -233,33 +224,37 @@ class TestEvaluation(unittest.TestCase):
                 *(["train_known"] * (len(frame_data.image_rgb) - 1)),
             ]

-            # move frame_data to gpu
             frame_data = dataclass_to_cuda_(frame_data)
             preds = model(**dataclasses.asdict(frame_data))

-            nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
             eval_result = eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=self.bg_color,
                 lpips_model=self.lpips_model,
             )

+            if check_metrics:
+                self._check_metrics(
+                    frame_data, preds["implicitron_render"], eval_result
+                )
+
+    def _check_metrics(self, frame_data, implicitron_render, eval_result):
         # Make a terribly bad NVS prediction and check that this is worse
         # than the DBIR prediction.
-        nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
-        nvs_prediction_bad.depth_render += (
-            torch.randn_like(nvs_prediction.depth_render) * 100.0
+        implicitron_render_bad = implicitron_render.clone()
+        implicitron_render_bad.depth_render += (
+            torch.randn_like(implicitron_render_bad.depth_render) * 100.0
         )
-        nvs_prediction_bad.image_render += (
-            torch.randn_like(nvs_prediction.image_render) * 100.0
+        implicitron_render_bad.image_render += (
+            torch.randn_like(implicitron_render_bad.image_render) * 100.0
         )
-        nvs_prediction_bad.mask_render = (
-            torch.randn_like(nvs_prediction.mask_render) > 0.0
+        implicitron_render_bad.mask_render = (
+            torch.randn_like(implicitron_render_bad.mask_render) > 0.0
         ).float()
         eval_result_bad = eval_batch(
             frame_data,
-            nvs_prediction_bad,
+            implicitron_render_bad,
             bg_color=self.bg_color,
             lpips_model=self.lpips_model,
         )
@@ -273,7 +268,7 @@ class TestEvaluation(unittest.TestCase):
             "rgb_l1_fg": True,
         }

-        for metric in lower_better.keys():
+        for metric in lower_better:
             m_better = eval_result[metric]
             m_worse = eval_result_bad[metric]
             if m_better != m_better or m_worse != m_worse:
@@ -285,9 +280,45 @@ class TestEvaluation(unittest.TestCase):
             )
             _assert(m_better, m_worse)

+    def _get_random_batch_indices(
+        self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10
+    ):
+        batch_indices = []
+        for _ in range(n_batches):
+            batch_size = torch.randint(
+                low=min_batch_size, high=max_batch_size, size=(1,)
+            )
+            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
+
+        return batch_indices
+
     def test_full_eval(self, n_sequences=5):
         """Test evaluation."""
+
+        # caching batch indices first to preserve RNG state
+        seq_datasets = {}
+        batch_indices = {}
         for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
             idx = list(self.dataset.sequence_indices_in_order(seq))
             seq_dataset = torch.utils.data.Subset(self.dataset, idx)
-            self._one_sequence_test(seq_dataset)
+            seq_datasets[seq] = seq_dataset
+            batch_indices[seq] = self._get_random_batch_indices(seq_dataset)
+
+        for model_class_type in ["ModelDBIR", "GenericModel"]:
+            ModelClass = registry.get(ImplicitronModelBase, model_class_type)
+            expand_args_fields(ModelClass)
+            model = ModelClass(
+                render_image_width=self.image_size,
+                render_image_height=self.image_size,
+                bg_color=self.bg_color,
+            )
+            model.eval()
+            model.cuda()
+
+            for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
+                self._one_sequence_test(
+                    seq_datasets[seq],
+                    model,
+                    batch_indices[seq],
+                    check_metrics=(model_class_type == "ModelDBIR"),
+                )
@@ -7,7 +7,7 @@
 import unittest

 import torch
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.renderer.base import EvaluationMode
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras