mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Refactor ViewMetrics
Summary: Make ViewMetrics easy to replace by putting them into an OmegaConf dataclass. Also, re-word a few variable names and fix minor TODOs. Reviewed By: bottler Differential Revision: D37327157 fbshipit-source-id: 78d8e39bbb3548b952f10abbe05688409fb987cc
This commit is contained in:
parent
f4dd151037
commit
ae35824f21
@ -21,6 +21,8 @@ generic_model_args:
|
||||
image_feature_extractor_class_type: null
|
||||
view_pooler_enabled: false
|
||||
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
|
||||
view_metrics_class_type: ViewMetrics
|
||||
regularization_metrics_class_type: RegularizationMetrics
|
||||
loss_weights:
|
||||
loss_rgb_mse: 1.0
|
||||
loss_prev_stage_rgb_mse: 1.0
|
||||
@ -268,6 +270,8 @@ generic_model_args:
|
||||
in_features: 256
|
||||
out_features: 3
|
||||
ray_dir_in_camera_coords: false
|
||||
view_metrics_ViewMetrics_args: {}
|
||||
regularization_metrics_RegularizationMetrics_args: {}
|
||||
solver_args:
|
||||
breed: adam
|
||||
weight_decay: 0.0
|
||||
|
@ -38,7 +38,8 @@ class ImplicitronRender:
|
||||
|
||||
|
||||
class ImplicitronModelBase(ReplaceableBase):
|
||||
"""Replaceable abstract base for all image generation / rendering models.
|
||||
"""
|
||||
Replaceable abstract base for all image generation / rendering models.
|
||||
`forward()` method produces a render with a depth map.
|
||||
"""
|
||||
|
||||
|
@ -16,6 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from pytorch3d.implicitron.models.metrics import ( # noqa
|
||||
RegularizationMetrics,
|
||||
RegularizationMetricsBase,
|
||||
ViewMetrics,
|
||||
ViewMetricsBase,
|
||||
)
|
||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
@ -42,7 +48,7 @@ from .implicit_function.scene_representation_networks import ( # noqa
|
||||
SRNHyperNetImplicitFunction,
|
||||
SRNImplicitFunction,
|
||||
)
|
||||
from .metrics import ViewMetrics
|
||||
|
||||
from .renderer.base import (
|
||||
BaseRenderer,
|
||||
EvaluationMode,
|
||||
@ -184,6 +190,14 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
is available in the global registry.
|
||||
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
||||
are initialised to be in self._implicit_functions.
|
||||
view_metrics: An instance of ViewMetricsBase used to compute loss terms which
|
||||
are independent of the model's parameters.
|
||||
view_metrics_class_type: The type of view metrics to use, must be available in
|
||||
the global registry.
|
||||
regularization_metrics: An instance of RegularizationMetricsBase used to compute
|
||||
regularization terms which can depend on the model's parameters.
|
||||
regularization_metrics_class_type: The type of regularization metrics to use,
|
||||
must be available in the global registry.
|
||||
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
|
||||
for `ViewMetrics` class for available loss functions.
|
||||
log_vars: A list of variable names which should be logged.
|
||||
@ -232,6 +246,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
# The actual implicit functions live in self._implicit_functions
|
||||
implicit_function: ImplicitFunctionBase
|
||||
|
||||
# ----- metrics
|
||||
view_metrics: ViewMetricsBase
|
||||
view_metrics_class_type: str = "ViewMetrics"
|
||||
|
||||
regularization_metrics: RegularizationMetricsBase
|
||||
regularization_metrics_class_type: str = "RegularizationMetrics"
|
||||
|
||||
# ---- loss weights
|
||||
loss_weights: Dict[str, float] = field(
|
||||
default_factory=lambda: {
|
||||
@ -269,7 +290,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
self.view_metrics = ViewMetrics()
|
||||
|
||||
if self.view_pooler_enabled:
|
||||
if self.image_feature_extractor_class_type is None:
|
||||
@ -424,15 +444,31 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
for func in self._implicit_functions:
|
||||
func.unbind_args()
|
||||
|
||||
preds = self._get_view_metrics(
|
||||
raymarched=rendered,
|
||||
xys=ray_bundle.xys,
|
||||
image_rgb=None if image_rgb is None else image_rgb[:n_targets],
|
||||
depth_map=None if depth_map is None else depth_map[:n_targets],
|
||||
fg_probability=None
|
||||
if fg_probability is None
|
||||
else fg_probability[:n_targets],
|
||||
mask_crop=None if mask_crop is None else mask_crop[:n_targets],
|
||||
# A dict to store losses as well as rendering results.
|
||||
preds: Dict[str, Any] = {}
|
||||
|
||||
def safe_slice_targets(
|
||||
tensor: Optional[torch.Tensor],
|
||||
) -> Optional[torch.Tensor]:
|
||||
return None if tensor is None else tensor[:n_targets]
|
||||
|
||||
preds.update(
|
||||
self.view_metrics(
|
||||
results=preds,
|
||||
raymarched=rendered,
|
||||
xys=ray_bundle.xys,
|
||||
image_rgb=safe_slice_targets(image_rgb),
|
||||
depth_map=safe_slice_targets(depth_map),
|
||||
fg_probability=safe_slice_targets(fg_probability),
|
||||
mask_crop=safe_slice_targets(mask_crop),
|
||||
)
|
||||
)
|
||||
|
||||
preds.update(
|
||||
self.regularization_metrics(
|
||||
results=preds,
|
||||
model=self,
|
||||
)
|
||||
)
|
||||
|
||||
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
|
||||
@ -462,11 +498,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
else:
|
||||
raise AssertionError("Unreachable state")
|
||||
|
||||
# calc the AD penalty, returns None if autodecoder is not active
|
||||
ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
|
||||
if ad_penalty is not None:
|
||||
preds["loss_autodecoder_norm"] = ad_penalty
|
||||
|
||||
# (7) Compute losses
|
||||
# finally get the optimization objective using self.loss_weights
|
||||
objective = self._get_objective(preds)
|
||||
@ -744,45 +775,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
|
||||
return image_rgb, fg_mask, depth_map
|
||||
|
||||
def _get_view_metrics(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
):
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
metrics = self.view_metrics(
|
||||
image_sampling_grid=xys,
|
||||
images_pred=raymarched.features,
|
||||
images=image_rgb,
|
||||
depths_pred=raymarched.depths,
|
||||
depths=depth_map,
|
||||
masks_pred=raymarched.masks,
|
||||
masks=fg_probability,
|
||||
masks_crop=mask_crop,
|
||||
keys_prefix=keys_prefix,
|
||||
**raymarched.aux,
|
||||
)
|
||||
|
||||
if raymarched.prev_stage:
|
||||
metrics.update(
|
||||
self._get_view_metrics(
|
||||
raymarched.prev_stage,
|
||||
xys,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
mask_crop,
|
||||
keys_prefix=(keys_prefix + "prev_stage_"),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def _rasterize_mc_samples(
|
||||
self,
|
||||
|
@ -6,64 +6,180 @@
|
||||
|
||||
|
||||
import warnings
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools import metric_utils as utils
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer import utils as rend_utils
|
||||
|
||||
from .renderer.base import RendererOutput
|
||||
|
||||
|
||||
class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Replaceable abstract base for regularization metrics.
|
||||
`forward()` method produces regularization metrics and (unlike ViewMetrics) can
|
||||
depend on the model's parameters.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates various regularization terms useful for supervising differentiable
|
||||
rendering pipelines.
|
||||
|
||||
Args:
|
||||
model: A model instance. Useful, for example, to implement
|
||||
weights-based regularization.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all regularization metrics.
|
||||
|
||||
Returns:
|
||||
A dictionary with the resulting regularization metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Replaceable abstract base for model metrics.
|
||||
`forward()` method produces losses and other metrics.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
class ViewMetrics(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
image_sampling_grid: torch.Tensor,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
images_pred: Optional[torch.Tensor] = None,
|
||||
depths: Optional[torch.Tensor] = None,
|
||||
depths_pred: Optional[torch.Tensor] = None,
|
||||
masks: Optional[torch.Tensor] = None,
|
||||
masks_pred: Optional[torch.Tensor] = None,
|
||||
masks_crop: Optional[torch.Tensor] = None,
|
||||
grad_theta: Optional[torch.Tensor] = None,
|
||||
density_grid: Optional[torch.Tensor] = None,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
mask_renders_by_pred: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates various metrics and loss functions useful for supervising
|
||||
differentiable rendering pipelines. Any additional parameters can be passed
|
||||
in the `raymarched.aux` dictionary.
|
||||
|
||||
Args:
|
||||
results: A dictionary with the resulting view metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
raymarched: Output of the renderer.
|
||||
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||
the predictions are defined. All ground truth inputs are sampled at
|
||||
these locations in order to extract values that correspond to the
|
||||
predictions.
|
||||
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||
values.
|
||||
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||
values.
|
||||
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||
foreground masks.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all view metrics.
|
||||
|
||||
Returns:
|
||||
A dictionary with the resulting view metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class RegularizationMetrics(RegularizationMetricsBase):
|
||||
def forward(
|
||||
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates the AD penalty, or returns an empty dict if the model's autoencoder
|
||||
is inactive.
|
||||
|
||||
Args:
|
||||
model: A model instance.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all regularization metrics.
|
||||
|
||||
Returns:
|
||||
A dictionary with the resulting regularization metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
|
||||
The calculated metric is:
|
||||
autoencoder_norm: Autoencoder weight norm regularization term.
|
||||
"""
|
||||
metrics = {}
|
||||
if getattr(model, "sequence_autodecoder", None) is not None:
|
||||
ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
|
||||
if ad_penalty is not None:
|
||||
metrics["autodecoder_norm"] = ad_penalty
|
||||
|
||||
if keys_prefix is not None:
|
||||
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
@registry.register
|
||||
class ViewMetrics(ViewMetricsBase):
|
||||
def forward(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates various differentiable metrics useful for supervising
|
||||
differentiable rendering pipelines.
|
||||
|
||||
Args:
|
||||
image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
|
||||
image locations at which the predictions are defined.
|
||||
All ground truth inputs are sampled at these
|
||||
locations in order to extract values that correspond
|
||||
to the predictions.
|
||||
images: A tensor of shape `(B, H, W, 3)` containing ground truth
|
||||
rgb values.
|
||||
images_pred: A tensor of shape `(B, ..., 3)` containing predicted
|
||||
rgb values.
|
||||
depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
|
||||
depth values.
|
||||
depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
||||
depth values.
|
||||
masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||
foreground masks.
|
||||
masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
||||
foreground masks.
|
||||
grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
|
||||
of a gradient of a signed distance function w.r.t.
|
||||
results: A dict to store the results in.
|
||||
raymarched.features: Predicted rgb or feature values.
|
||||
raymarched.depths: A tensor of shape `(B, ..., 1)` containing
|
||||
predicted depth values.
|
||||
raymarched.masks: A tensor of shape `(B, ..., 1)` containing
|
||||
predicted foreground masks.
|
||||
raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing an
|
||||
evaluation of a gradient of a signed distance function w.r.t.
|
||||
input 3D coordinates used to compute the eikonal loss.
|
||||
density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
|
||||
`Hg x Wg x Dg` voxel grid of density values.
|
||||
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
|
||||
containing a `Hg x Wg x Dg` voxel grid of density values.
|
||||
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||
the predictions are defined. All ground truth inputs are sampled at
|
||||
these locations in order to extract values that correspond to the
|
||||
predictions.
|
||||
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||
values.
|
||||
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||
values.
|
||||
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||
foreground masks.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all metrics.
|
||||
mask_renders_by_pred: If `True`, masks rendered images by the predicted
|
||||
`masks_pred` prior to computing all rgb metrics.
|
||||
containing all view metrics.
|
||||
|
||||
Returns:
|
||||
metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
||||
A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
|
||||
@ -91,109 +207,142 @@ class ViewMetrics(torch.nn.Module):
|
||||
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
|
||||
predicted depth values.
|
||||
"""
|
||||
metrics = self._calculate_stage(
|
||||
raymarched,
|
||||
xys,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
mask_crop,
|
||||
keys_prefix,
|
||||
)
|
||||
|
||||
if raymarched.prev_stage:
|
||||
metrics.update(
|
||||
self(
|
||||
raymarched.prev_stage,
|
||||
xys,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
mask_crop,
|
||||
keys_prefix=(keys_prefix + "prev_stage_"),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _calculate_stage(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate metrics for the current stage.
|
||||
"""
|
||||
# TODO: extract functions
|
||||
|
||||
# reshape from B x ... x DIM to B x DIM x -1 x 1
|
||||
images_pred, masks_pred, depths_pred = [
|
||||
_reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
|
||||
image_rgb_pred, fg_probability_pred, depth_map_pred = [
|
||||
_reshape_nongrid_var(x)
|
||||
for x in [raymarched.features, raymarched.masks, raymarched.depths]
|
||||
]
|
||||
# reshape the sampling grid as well
|
||||
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
|
||||
# now that we use rend_utils.ndc_grid_sample
|
||||
image_sampling_grid = image_sampling_grid.reshape(
|
||||
image_sampling_grid.shape[0], -1, 1, 2
|
||||
)
|
||||
xys = xys.reshape(xys.shape[0], -1, 1, 2)
|
||||
|
||||
# closure with the given image_sampling_grid
|
||||
# closure with the given xys
|
||||
def sample(tensor, mode):
|
||||
if tensor is None:
|
||||
return tensor
|
||||
return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
|
||||
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
|
||||
|
||||
# eval all results in this size
|
||||
images = sample(images, mode="bilinear")
|
||||
depths = sample(depths, mode="nearest")
|
||||
masks = sample(masks, mode="nearest")
|
||||
masks_crop = sample(masks_crop, mode="nearest")
|
||||
if masks_crop is None and images_pred is not None:
|
||||
masks_crop = torch.ones_like(images_pred[:, :1])
|
||||
if masks_crop is None and depths_pred is not None:
|
||||
masks_crop = torch.ones_like(depths_pred[:, :1])
|
||||
image_rgb = sample(image_rgb, mode="bilinear")
|
||||
depth_map = sample(depth_map, mode="nearest")
|
||||
fg_probability = sample(fg_probability, mode="nearest")
|
||||
mask_crop = sample(mask_crop, mode="nearest")
|
||||
if mask_crop is None and image_rgb_pred is not None:
|
||||
mask_crop = torch.ones_like(image_rgb_pred[:, :1])
|
||||
if mask_crop is None and depth_map_pred is not None:
|
||||
mask_crop = torch.ones_like(depth_map_pred[:, :1])
|
||||
|
||||
preds = {}
|
||||
if images is not None and images_pred is not None:
|
||||
# TODO: mask_renders_by_pred is always false; simplify
|
||||
preds.update(
|
||||
metrics = {}
|
||||
if image_rgb is not None and image_rgb_pred is not None:
|
||||
metrics.update(
|
||||
_rgb_metrics(
|
||||
images,
|
||||
images_pred,
|
||||
masks,
|
||||
masks_pred,
|
||||
masks_crop,
|
||||
mask_renders_by_pred,
|
||||
image_rgb,
|
||||
image_rgb_pred,
|
||||
fg_probability,
|
||||
fg_probability_pred,
|
||||
mask_crop,
|
||||
)
|
||||
)
|
||||
|
||||
if masks_pred is not None:
|
||||
preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
|
||||
if masks is not None and masks_pred is not None:
|
||||
preds["mask_neg_iou"] = utils.neg_iou_loss(
|
||||
masks_pred, masks, mask=masks_crop
|
||||
if fg_probability_pred is not None:
|
||||
metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
|
||||
if fg_probability is not None and fg_probability_pred is not None:
|
||||
metrics["mask_neg_iou"] = utils.neg_iou_loss(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
|
||||
|
||||
if depths is not None and depths_pred is not None:
|
||||
assert masks_crop is not None
|
||||
if depth_map is not None and depth_map_pred is not None:
|
||||
assert mask_crop is not None
|
||||
_, abs_ = utils.eval_depth(
|
||||
depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
|
||||
depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
|
||||
)
|
||||
preds["depth_abs"] = abs_.mean()
|
||||
metrics["depth_abs"] = abs_.mean()
|
||||
|
||||
if masks is not None:
|
||||
mask = masks * masks_crop
|
||||
if fg_probability is not None:
|
||||
mask = fg_probability * mask_crop
|
||||
_, abs_ = utils.eval_depth(
|
||||
depths_pred, depths, get_best_scale=True, mask=mask, crop=0
|
||||
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
|
||||
)
|
||||
preds["depth_abs_fg"] = abs_.mean()
|
||||
metrics["depth_abs_fg"] = abs_.mean()
|
||||
|
||||
# regularizers
|
||||
grad_theta = raymarched.aux.get("grad_theta")
|
||||
if grad_theta is not None:
|
||||
preds["eikonal"] = _get_eikonal_loss(grad_theta)
|
||||
metrics["eikonal"] = _get_eikonal_loss(grad_theta)
|
||||
|
||||
density_grid = raymarched.aux.get("density_grid")
|
||||
if density_grid is not None:
|
||||
preds["density_tv"] = _get_grid_tv_loss(density_grid)
|
||||
metrics["density_tv"] = _get_grid_tv_loss(density_grid)
|
||||
|
||||
if depths_pred is not None:
|
||||
preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
|
||||
if depth_map_pred is not None:
|
||||
metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)
|
||||
|
||||
if keys_prefix is not None:
|
||||
preds = {(keys_prefix + k): v for k, v in preds.items()}
|
||||
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||
|
||||
return preds
|
||||
return metrics
|
||||
|
||||
|
||||
def _rgb_metrics(
|
||||
images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
|
||||
):
|
||||
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
|
||||
assert masks_crop is not None
|
||||
if mask_renders_by_pred:
|
||||
images = images[..., masks_pred.reshape(-1), :]
|
||||
masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
|
||||
masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
|
||||
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
||||
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
||||
crop_mass = masks_crop.sum().clamp(1.0)
|
||||
preds = {
|
||||
results = {
|
||||
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
||||
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
||||
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
||||
}
|
||||
if masks is not None:
|
||||
masks = masks_crop * masks
|
||||
preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
||||
preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
||||
return preds
|
||||
results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
||||
results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
||||
return results
|
||||
|
||||
|
||||
def _get_eikonal_loss(grad_theta):
|
||||
|
@ -20,6 +20,8 @@ renderer_class_type: LSTMRenderer
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
view_pooler_enabled: true
|
||||
implicit_function_class_type: IdrFeatureField
|
||||
view_metrics_class_type: ViewMetrics
|
||||
regularization_metrics_class_type: RegularizationMetrics
|
||||
loss_weights:
|
||||
loss_rgb_mse: 1.0
|
||||
loss_prev_stage_rgb_mse: 1.0
|
||||
@ -122,3 +124,5 @@ implicit_function_IdrFeatureField_args:
|
||||
n_harmonic_functions_xyz: 1729
|
||||
pooled_feature_dim: 0
|
||||
encoding_dim: 0
|
||||
view_metrics_ViewMetrics_args: {}
|
||||
regularization_metrics_RegularizationMetrics_args: {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user