mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Refactor ViewMetrics
Summary: Make ViewMetrics easy to replace by putting them into an OmegaConf dataclass. Also, re-word a few variable names and fix minor TODOs.
Reviewed By: bottler
Differential Revision: D37327157
fbshipit-source-id: 78d8e39bbb3548b952f10abbe05688409fb987cc
This commit is contained in:
parent
f4dd151037
commit
ae35824f21
@ -21,6 +21,8 @@ generic_model_args:
|
|||||||
image_feature_extractor_class_type: null
|
image_feature_extractor_class_type: null
|
||||||
view_pooler_enabled: false
|
view_pooler_enabled: false
|
||||||
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
|
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
|
||||||
|
view_metrics_class_type: ViewMetrics
|
||||||
|
regularization_metrics_class_type: RegularizationMetrics
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_rgb_mse: 1.0
|
loss_rgb_mse: 1.0
|
||||||
loss_prev_stage_rgb_mse: 1.0
|
loss_prev_stage_rgb_mse: 1.0
|
||||||
@ -268,6 +270,8 @@ generic_model_args:
|
|||||||
in_features: 256
|
in_features: 256
|
||||||
out_features: 3
|
out_features: 3
|
||||||
ray_dir_in_camera_coords: false
|
ray_dir_in_camera_coords: false
|
||||||
|
view_metrics_ViewMetrics_args: {}
|
||||||
|
regularization_metrics_RegularizationMetrics_args: {}
|
||||||
solver_args:
|
solver_args:
|
||||||
breed: adam
|
breed: adam
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
@ -38,7 +38,8 @@ class ImplicitronRender:
|
|||||||
|
|
||||||
|
|
||||||
class ImplicitronModelBase(ReplaceableBase):
|
class ImplicitronModelBase(ReplaceableBase):
|
||||||
"""Replaceable abstract base for all image generation / rendering models.
|
"""
|
||||||
|
Replaceable abstract base for all image generation / rendering models.
|
||||||
`forward()` method produces a render with a depth map.
|
`forward()` method produces a render with a depth map.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -16,6 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from pytorch3d.implicitron.models.metrics import ( # noqa
|
||||||
|
RegularizationMetrics,
|
||||||
|
RegularizationMetricsBase,
|
||||||
|
ViewMetrics,
|
||||||
|
ViewMetricsBase,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
expand_args_fields,
|
expand_args_fields,
|
||||||
@ -42,7 +48,7 @@ from .implicit_function.scene_representation_networks import ( # noqa
|
|||||||
SRNHyperNetImplicitFunction,
|
SRNHyperNetImplicitFunction,
|
||||||
SRNImplicitFunction,
|
SRNImplicitFunction,
|
||||||
)
|
)
|
||||||
from .metrics import ViewMetrics
|
|
||||||
from .renderer.base import (
|
from .renderer.base import (
|
||||||
BaseRenderer,
|
BaseRenderer,
|
||||||
EvaluationMode,
|
EvaluationMode,
|
||||||
@ -184,6 +190,14 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
is available in the global registry.
|
is available in the global registry.
|
||||||
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
||||||
are initialised to be in self._implicit_functions.
|
are initialised to be in self._implicit_functions.
|
||||||
|
view_metrics: An instance of ViewMetricsBase used to compute loss terms which
|
||||||
|
are independent of the model's parameters.
|
||||||
|
view_metrics_class_type: The type of view metrics to use, must be available in
|
||||||
|
the global registry.
|
||||||
|
regularization_metrics: An instance of RegularizationMetricsBase used to compute
|
||||||
|
regularization terms which can depend on the model's parameters.
|
||||||
|
regularization_metrics_class_type: The type of regularization metrics to use,
|
||||||
|
must be available in the global registry.
|
||||||
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
|
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
|
||||||
for `ViewMetrics` class for available loss functions.
|
for `ViewMetrics` class for available loss functions.
|
||||||
log_vars: A list of variable names which should be logged.
|
log_vars: A list of variable names which should be logged.
|
||||||
@ -232,6 +246,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
# The actual implicit functions live in self._implicit_functions
|
# The actual implicit functions live in self._implicit_functions
|
||||||
implicit_function: ImplicitFunctionBase
|
implicit_function: ImplicitFunctionBase
|
||||||
|
|
||||||
|
# ----- metrics
|
||||||
|
view_metrics: ViewMetricsBase
|
||||||
|
view_metrics_class_type: str = "ViewMetrics"
|
||||||
|
|
||||||
|
regularization_metrics: RegularizationMetricsBase
|
||||||
|
regularization_metrics_class_type: str = "RegularizationMetrics"
|
||||||
|
|
||||||
# ---- loss weights
|
# ---- loss weights
|
||||||
loss_weights: Dict[str, float] = field(
|
loss_weights: Dict[str, float] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@ -269,7 +290,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.view_metrics = ViewMetrics()
|
|
||||||
|
|
||||||
if self.view_pooler_enabled:
|
if self.view_pooler_enabled:
|
||||||
if self.image_feature_extractor_class_type is None:
|
if self.image_feature_extractor_class_type is None:
|
||||||
@ -424,15 +444,31 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
for func in self._implicit_functions:
|
for func in self._implicit_functions:
|
||||||
func.unbind_args()
|
func.unbind_args()
|
||||||
|
|
||||||
preds = self._get_view_metrics(
|
# A dict to store losses as well as rendering results.
|
||||||
|
preds: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def safe_slice_targets(
|
||||||
|
tensor: Optional[torch.Tensor],
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
return None if tensor is None else tensor[:n_targets]
|
||||||
|
|
||||||
|
preds.update(
|
||||||
|
self.view_metrics(
|
||||||
|
results=preds,
|
||||||
raymarched=rendered,
|
raymarched=rendered,
|
||||||
xys=ray_bundle.xys,
|
xys=ray_bundle.xys,
|
||||||
image_rgb=None if image_rgb is None else image_rgb[:n_targets],
|
image_rgb=safe_slice_targets(image_rgb),
|
||||||
depth_map=None if depth_map is None else depth_map[:n_targets],
|
depth_map=safe_slice_targets(depth_map),
|
||||||
fg_probability=None
|
fg_probability=safe_slice_targets(fg_probability),
|
||||||
if fg_probability is None
|
mask_crop=safe_slice_targets(mask_crop),
|
||||||
else fg_probability[:n_targets],
|
)
|
||||||
mask_crop=None if mask_crop is None else mask_crop[:n_targets],
|
)
|
||||||
|
|
||||||
|
preds.update(
|
||||||
|
self.regularization_metrics(
|
||||||
|
results=preds,
|
||||||
|
model=self,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
|
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
|
||||||
@ -462,11 +498,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
else:
|
else:
|
||||||
raise AssertionError("Unreachable state")
|
raise AssertionError("Unreachable state")
|
||||||
|
|
||||||
# calc the AD penalty, returns None if autodecoder is not active
|
|
||||||
ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
|
|
||||||
if ad_penalty is not None:
|
|
||||||
preds["loss_autodecoder_norm"] = ad_penalty
|
|
||||||
|
|
||||||
# (7) Compute losses
|
# (7) Compute losses
|
||||||
# finally get the optimization objective using self.loss_weights
|
# finally get the optimization objective using self.loss_weights
|
||||||
objective = self._get_objective(preds)
|
objective = self._get_objective(preds)
|
||||||
@ -744,45 +775,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
|
|
||||||
return image_rgb, fg_mask, depth_map
|
return image_rgb, fg_mask, depth_map
|
||||||
|
|
||||||
def _get_view_metrics(
|
|
||||||
self,
|
|
||||||
raymarched: RendererOutput,
|
|
||||||
xys: torch.Tensor,
|
|
||||||
image_rgb: Optional[torch.Tensor] = None,
|
|
||||||
depth_map: Optional[torch.Tensor] = None,
|
|
||||||
fg_probability: Optional[torch.Tensor] = None,
|
|
||||||
mask_crop: Optional[torch.Tensor] = None,
|
|
||||||
keys_prefix: str = "loss_",
|
|
||||||
):
|
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
|
||||||
metrics = self.view_metrics(
|
|
||||||
image_sampling_grid=xys,
|
|
||||||
images_pred=raymarched.features,
|
|
||||||
images=image_rgb,
|
|
||||||
depths_pred=raymarched.depths,
|
|
||||||
depths=depth_map,
|
|
||||||
masks_pred=raymarched.masks,
|
|
||||||
masks=fg_probability,
|
|
||||||
masks_crop=mask_crop,
|
|
||||||
keys_prefix=keys_prefix,
|
|
||||||
**raymarched.aux,
|
|
||||||
)
|
|
||||||
|
|
||||||
if raymarched.prev_stage:
|
|
||||||
metrics.update(
|
|
||||||
self._get_view_metrics(
|
|
||||||
raymarched.prev_stage,
|
|
||||||
xys,
|
|
||||||
image_rgb,
|
|
||||||
depth_map,
|
|
||||||
fg_probability,
|
|
||||||
mask_crop,
|
|
||||||
keys_prefix=(keys_prefix + "prev_stage_"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _rasterize_mc_samples(
|
def _rasterize_mc_samples(
|
||||||
self,
|
self,
|
||||||
|
@ -6,64 +6,180 @@
|
|||||||
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools import metric_utils as utils
|
from pytorch3d.implicitron.tools import metric_utils as utils
|
||||||
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer import utils as rend_utils
|
from pytorch3d.renderer import utils as rend_utils
|
||||||
|
|
||||||
|
from .renderer.base import RendererOutput
|
||||||
|
|
||||||
|
|
||||||
|
class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Replaceable abstract base for regularization metrics.
|
||||||
|
`forward()` method produces regularization metrics and (unlike ViewMetrics) can
|
||||||
|
depend on the model's parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculates various regularization terms useful for supervising differentiable
|
||||||
|
rendering pipelines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A model instance. Useful, for example, to implement
|
||||||
|
weights-based regularization.
|
||||||
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
|
containing all regularization metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the resulting regularization metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Replaceable abstract base for model metrics.
|
||||||
|
`forward()` method produces losses and other metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
class ViewMetrics(torch.nn.Module):
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
image_sampling_grid: torch.Tensor,
|
raymarched: RendererOutput,
|
||||||
images: Optional[torch.Tensor] = None,
|
xys: torch.Tensor,
|
||||||
images_pred: Optional[torch.Tensor] = None,
|
image_rgb: Optional[torch.Tensor] = None,
|
||||||
depths: Optional[torch.Tensor] = None,
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
depths_pred: Optional[torch.Tensor] = None,
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
masks: Optional[torch.Tensor] = None,
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
masks_pred: Optional[torch.Tensor] = None,
|
|
||||||
masks_crop: Optional[torch.Tensor] = None,
|
|
||||||
grad_theta: Optional[torch.Tensor] = None,
|
|
||||||
density_grid: Optional[torch.Tensor] = None,
|
|
||||||
keys_prefix: str = "loss_",
|
keys_prefix: str = "loss_",
|
||||||
mask_renders_by_pred: bool = False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculates various metrics and loss functions useful for supervising
|
||||||
|
differentiable rendering pipelines. Any additional parameters can be passed
|
||||||
|
in the `raymarched.aux` dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: A dictionary with the resulting view metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
raymarched: Output of the renderer.
|
||||||
|
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||||
|
the predictions are defined. All ground truth inputs are sampled at
|
||||||
|
these locations in order to extract values that correspond to the
|
||||||
|
predictions.
|
||||||
|
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||||
|
values.
|
||||||
|
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||||
|
values.
|
||||||
|
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||||
|
foreground masks.
|
||||||
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
|
containing all view metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the resulting view metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class RegularizationMetrics(RegularizationMetricsBase):
|
||||||
|
def forward(
|
||||||
|
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculates the AD penalty, or returns an empty dict if the model's autodecoder
|
||||||
|
is inactive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A model instance.
|
||||||
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
|
containing all regularization metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the resulting regularization metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
|
||||||
|
The calculated metric is:
|
||||||
|
autodecoder_norm: Autodecoder encoding-norm regularization term (the value returned by `calc_squared_encoding_norm()`).
|
||||||
|
"""
|
||||||
|
metrics = {}
|
||||||
|
if getattr(model, "sequence_autodecoder", None) is not None:
|
||||||
|
ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
|
||||||
|
if ad_penalty is not None:
|
||||||
|
metrics["autodecoder_norm"] = ad_penalty
|
||||||
|
|
||||||
|
if keys_prefix is not None:
|
||||||
|
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class ViewMetrics(ViewMetricsBase):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
raymarched: RendererOutput,
|
||||||
|
xys: torch.Tensor,
|
||||||
|
image_rgb: Optional[torch.Tensor] = None,
|
||||||
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
|
keys_prefix: str = "loss_",
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Calculates various differentiable metrics useful for supervising
|
Calculates various differentiable metrics useful for supervising
|
||||||
differentiable rendering pipelines.
|
differentiable rendering pipelines.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
|
results: A dict to store the results in.
|
||||||
image locations at which the predictions are defined.
|
raymarched.features: Predicted rgb or feature values.
|
||||||
All ground truth inputs are sampled at these
|
raymarched.depths: A tensor of shape `(B, ..., 1)` containing
|
||||||
locations in order to extract values that correspond
|
predicted depth values.
|
||||||
to the predictions.
|
raymarched.masks: A tensor of shape `(B, ..., 1)` containing
|
||||||
images: A tensor of shape `(B, H, W, 3)` containing ground truth
|
predicted foreground masks.
|
||||||
rgb values.
|
raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing an
|
||||||
images_pred: A tensor of shape `(B, ..., 3)` containing predicted
|
evaluation of a gradient of a signed distance function w.r.t.
|
||||||
rgb values.
|
|
||||||
depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
|
|
||||||
depth values.
|
|
||||||
depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
|
||||||
depth values.
|
|
||||||
masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
|
||||||
foreground masks.
|
|
||||||
masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
|
||||||
foreground masks.
|
|
||||||
grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
|
|
||||||
of a gradient of a signed distance function w.r.t.
|
|
||||||
input 3D coordinates used to compute the eikonal loss.
|
input 3D coordinates used to compute the eikonal loss.
|
||||||
density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
|
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
|
||||||
`Hg x Wg x Dg` voxel grid of density values.
|
containing a `Hg x Wg x Dg` voxel grid of density values.
|
||||||
|
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||||
|
the predictions are defined. All ground truth inputs are sampled at
|
||||||
|
these locations in order to extract values that correspond to the
|
||||||
|
predictions.
|
||||||
|
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||||
|
values.
|
||||||
|
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||||
|
values.
|
||||||
|
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||||
|
foreground masks.
|
||||||
keys_prefix: A common prefix for all keys in the output dictionary
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
containing all metrics.
|
containing all view metrics.
|
||||||
mask_renders_by_pred: If `True`, masks rendered images by the predicted
|
|
||||||
`masks_pred` prior to computing all rgb metrics.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
||||||
names of the output metrics `metric_name_i` with their corresponding
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
|
||||||
@ -91,109 +207,142 @@ class ViewMetrics(torch.nn.Module):
|
|||||||
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
|
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
|
||||||
predicted depth values.
|
predicted depth values.
|
||||||
"""
|
"""
|
||||||
|
metrics = self._calculate_stage(
|
||||||
|
raymarched,
|
||||||
|
xys,
|
||||||
|
image_rgb,
|
||||||
|
depth_map,
|
||||||
|
fg_probability,
|
||||||
|
mask_crop,
|
||||||
|
keys_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
if raymarched.prev_stage:
|
||||||
|
metrics.update(
|
||||||
|
self(
|
||||||
|
raymarched.prev_stage,
|
||||||
|
xys,
|
||||||
|
image_rgb,
|
||||||
|
depth_map,
|
||||||
|
fg_probability,
|
||||||
|
mask_crop,
|
||||||
|
keys_prefix=(keys_prefix + "prev_stage_"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _calculate_stage(
|
||||||
|
self,
|
||||||
|
raymarched: RendererOutput,
|
||||||
|
xys: torch.Tensor,
|
||||||
|
image_rgb: Optional[torch.Tensor] = None,
|
||||||
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
|
keys_prefix: str = "loss_",
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculate metrics for the current stage.
|
||||||
|
"""
|
||||||
# TODO: extract functions
|
# TODO: extract functions
|
||||||
|
|
||||||
# reshape from B x ... x DIM to B x DIM x -1 x 1
|
# reshape from B x ... x DIM to B x DIM x -1 x 1
|
||||||
images_pred, masks_pred, depths_pred = [
|
image_rgb_pred, fg_probability_pred, depth_map_pred = [
|
||||||
_reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
|
_reshape_nongrid_var(x)
|
||||||
|
for x in [raymarched.features, raymarched.masks, raymarched.depths]
|
||||||
]
|
]
|
||||||
# reshape the sampling grid as well
|
# reshape the sampling grid as well
|
||||||
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
|
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
|
||||||
# now that we use rend_utils.ndc_grid_sample
|
# now that we use rend_utils.ndc_grid_sample
|
||||||
image_sampling_grid = image_sampling_grid.reshape(
|
xys = xys.reshape(xys.shape[0], -1, 1, 2)
|
||||||
image_sampling_grid.shape[0], -1, 1, 2
|
|
||||||
)
|
|
||||||
|
|
||||||
# closure with the given image_sampling_grid
|
# closure with the given xys
|
||||||
def sample(tensor, mode):
|
def sample(tensor, mode):
|
||||||
if tensor is None:
|
if tensor is None:
|
||||||
return tensor
|
return tensor
|
||||||
return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
|
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
|
||||||
|
|
||||||
# eval all results in this size
|
# eval all results in this size
|
||||||
images = sample(images, mode="bilinear")
|
image_rgb = sample(image_rgb, mode="bilinear")
|
||||||
depths = sample(depths, mode="nearest")
|
depth_map = sample(depth_map, mode="nearest")
|
||||||
masks = sample(masks, mode="nearest")
|
fg_probability = sample(fg_probability, mode="nearest")
|
||||||
masks_crop = sample(masks_crop, mode="nearest")
|
mask_crop = sample(mask_crop, mode="nearest")
|
||||||
if masks_crop is None and images_pred is not None:
|
if mask_crop is None and image_rgb_pred is not None:
|
||||||
masks_crop = torch.ones_like(images_pred[:, :1])
|
mask_crop = torch.ones_like(image_rgb_pred[:, :1])
|
||||||
if masks_crop is None and depths_pred is not None:
|
if mask_crop is None and depth_map_pred is not None:
|
||||||
masks_crop = torch.ones_like(depths_pred[:, :1])
|
mask_crop = torch.ones_like(depth_map_pred[:, :1])
|
||||||
|
|
||||||
preds = {}
|
metrics = {}
|
||||||
if images is not None and images_pred is not None:
|
if image_rgb is not None and image_rgb_pred is not None:
|
||||||
# TODO: mask_renders_by_pred is always false; simplify
|
metrics.update(
|
||||||
preds.update(
|
|
||||||
_rgb_metrics(
|
_rgb_metrics(
|
||||||
images,
|
image_rgb,
|
||||||
images_pred,
|
image_rgb_pred,
|
||||||
masks,
|
fg_probability,
|
||||||
masks_pred,
|
fg_probability_pred,
|
||||||
masks_crop,
|
mask_crop,
|
||||||
mask_renders_by_pred,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if masks_pred is not None:
|
if fg_probability_pred is not None:
|
||||||
preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
|
metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
|
||||||
if masks is not None and masks_pred is not None:
|
if fg_probability is not None and fg_probability_pred is not None:
|
||||||
preds["mask_neg_iou"] = utils.neg_iou_loss(
|
metrics["mask_neg_iou"] = utils.neg_iou_loss(
|
||||||
masks_pred, masks, mask=masks_crop
|
fg_probability_pred, fg_probability, mask=mask_crop
|
||||||
|
)
|
||||||
|
metrics["mask_bce"] = utils.calc_bce(
|
||||||
|
fg_probability_pred, fg_probability, mask=mask_crop
|
||||||
)
|
)
|
||||||
preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
|
|
||||||
|
|
||||||
if depths is not None and depths_pred is not None:
|
if depth_map is not None and depth_map_pred is not None:
|
||||||
assert masks_crop is not None
|
assert mask_crop is not None
|
||||||
_, abs_ = utils.eval_depth(
|
_, abs_ = utils.eval_depth(
|
||||||
depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
|
depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
|
||||||
)
|
)
|
||||||
preds["depth_abs"] = abs_.mean()
|
metrics["depth_abs"] = abs_.mean()
|
||||||
|
|
||||||
if masks is not None:
|
if fg_probability is not None:
|
||||||
mask = masks * masks_crop
|
mask = fg_probability * mask_crop
|
||||||
_, abs_ = utils.eval_depth(
|
_, abs_ = utils.eval_depth(
|
||||||
depths_pred, depths, get_best_scale=True, mask=mask, crop=0
|
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
|
||||||
)
|
)
|
||||||
preds["depth_abs_fg"] = abs_.mean()
|
metrics["depth_abs_fg"] = abs_.mean()
|
||||||
|
|
||||||
# regularizers
|
# regularizers
|
||||||
|
grad_theta = raymarched.aux.get("grad_theta")
|
||||||
if grad_theta is not None:
|
if grad_theta is not None:
|
||||||
preds["eikonal"] = _get_eikonal_loss(grad_theta)
|
metrics["eikonal"] = _get_eikonal_loss(grad_theta)
|
||||||
|
|
||||||
|
density_grid = raymarched.aux.get("density_grid")
|
||||||
if density_grid is not None:
|
if density_grid is not None:
|
||||||
preds["density_tv"] = _get_grid_tv_loss(density_grid)
|
metrics["density_tv"] = _get_grid_tv_loss(density_grid)
|
||||||
|
|
||||||
if depths_pred is not None:
|
if depth_map_pred is not None:
|
||||||
preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
|
metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)
|
||||||
|
|
||||||
if keys_prefix is not None:
|
if keys_prefix is not None:
|
||||||
preds = {(keys_prefix + k): v for k, v in preds.items()}
|
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||||
|
|
||||||
return preds
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
def _rgb_metrics(
|
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
|
||||||
images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
|
|
||||||
):
|
|
||||||
assert masks_crop is not None
|
assert masks_crop is not None
|
||||||
if mask_renders_by_pred:
|
|
||||||
images = images[..., masks_pred.reshape(-1), :]
|
|
||||||
masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
|
|
||||||
masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
|
|
||||||
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
||||||
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
||||||
crop_mass = masks_crop.sum().clamp(1.0)
|
crop_mass = masks_crop.sum().clamp(1.0)
|
||||||
preds = {
|
results = {
|
||||||
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
||||||
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
||||||
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
||||||
}
|
}
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
masks = masks_crop * masks
|
masks = masks_crop * masks
|
||||||
preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
||||||
preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
||||||
return preds
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _get_eikonal_loss(grad_theta):
|
def _get_eikonal_loss(grad_theta):
|
||||||
|
@ -20,6 +20,8 @@ renderer_class_type: LSTMRenderer
|
|||||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
view_pooler_enabled: true
|
view_pooler_enabled: true
|
||||||
implicit_function_class_type: IdrFeatureField
|
implicit_function_class_type: IdrFeatureField
|
||||||
|
view_metrics_class_type: ViewMetrics
|
||||||
|
regularization_metrics_class_type: RegularizationMetrics
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_rgb_mse: 1.0
|
loss_rgb_mse: 1.0
|
||||||
loss_prev_stage_rgb_mse: 1.0
|
loss_prev_stage_rgb_mse: 1.0
|
||||||
@ -122,3 +124,5 @@ implicit_function_IdrFeatureField_args:
|
|||||||
n_harmonic_functions_xyz: 1729
|
n_harmonic_functions_xyz: 1729
|
||||||
pooled_feature_dim: 0
|
pooled_feature_dim: 0
|
||||||
encoding_dim: 0
|
encoding_dim: 0
|
||||||
|
view_metrics_ViewMetrics_args: {}
|
||||||
|
regularization_metrics_RegularizationMetrics_args: {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user