Refactor ViewMetrics

Summary:
Make ViewMetrics easy to replace by putting it into an OmegaConf dataclass.

Also, rename a few variables and fix minor TODOs.
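
As an aside, here is a minimal sketch (not part of this diff) of what the refactor enables: a custom metrics implementation can now be registered and selected through config, instead of the previously hard-coded `self.view_metrics = ViewMetrics()`. `MyViewMetrics` and its `mask_mean` metric are hypothetical; the base class, registry, and import paths are the ones this diff touches.

    from typing import Any, Dict, Optional

    import torch
    from pytorch3d.implicitron.models.metrics import ViewMetricsBase
    from pytorch3d.implicitron.models.renderer.base import RendererOutput
    from pytorch3d.implicitron.tools.config import registry

    @registry.register
    class MyViewMetrics(ViewMetricsBase):
        # Hypothetical subclass: emits a single illustrative metric instead of
        # the full photometric / depth / mask losses computed by ViewMetrics.
        def forward(
            self,
            raymarched: RendererOutput,
            xys: torch.Tensor,
            image_rgb: Optional[torch.Tensor] = None,
            depth_map: Optional[torch.Tensor] = None,
            fg_probability: Optional[torch.Tensor] = None,
            mask_crop: Optional[torch.Tensor] = None,
            keys_prefix: str = "loss_",
            **kwargs,
        ) -> Dict[str, Any]:
            metrics: Dict[str, Any] = {}
            if raymarched.masks is not None:
                # Mean predicted mask value, a 0-dimensional float tensor.
                metrics["mask_mean"] = raymarched.masks.mean()
            return {(keys_prefix + k): v for k, v in metrics.items()}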

Reviewed By: bottler

Differential Revision: D37327157

fbshipit-source-id: 78d8e39bbb3548b952f10abbe05688409fb987cc
Krzysztof Chalupka, 2022-06-30 09:22:01 -07:00, committed by Facebook GitHub Bot
parent f4dd151037
commit ae35824f21
5 changed files with 301 additions and 151 deletions

View File

@@ -21,6 +21,8 @@ generic_model_args:
image_feature_extractor_class_type: null
view_pooler_enabled: false
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
view_metrics_class_type: ViewMetrics
regularization_metrics_class_type: RegularizationMetrics
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
@@ -268,6 +270,8 @@ generic_model_args:
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {}
solver_args:
breed: adam
weight_decay: 0.0

View File

@@ -38,7 +38,8 @@ class ImplicitronRender:
class ImplicitronModelBase(ReplaceableBase):
"""
Replaceable abstract base for all image generation / rendering models.
The `forward()` method produces a render with a depth map.
"""

View File

@@ -16,6 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
import tqdm
from pytorch3d.implicitron.models.metrics import ( # noqa
RegularizationMetrics,
RegularizationMetricsBase,
ViewMetrics,
ViewMetricsBase,
)
from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
@@ -42,7 +48,7 @@ from .implicit_function.scene_representation_networks import ( # noqa
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
)
from .metrics import ViewMetrics
from .renderer.base import (
BaseRenderer,
EvaluationMode,
@@ -184,6 +190,14 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
is available in the global registry.
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
are initialised to be in self._implicit_functions.
view_metrics: An instance of ViewMetricsBase used to compute loss terms which
are independent of the model's parameters.
view_metrics_class_type: The type of view metrics to use, must be available in
the global registry.
regularization_metrics: An instance of RegularizationMetricsBase used to compute
regularization terms which can depend on the model's parameters.
regularization_metrics_class_type: The type of regularization metrics to use,
must be available in the global registry.
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
for `ViewMetrics` class for available loss functions.
log_vars: A list of variable names which should be logged.
@@ -232,6 +246,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# The actual implicit functions live in self._implicit_functions
implicit_function: ImplicitFunctionBase
# ----- metrics
view_metrics: ViewMetricsBase
view_metrics_class_type: str = "ViewMetrics"
regularization_metrics: RegularizationMetricsBase
regularization_metrics_class_type: str = "RegularizationMetrics"
# ---- loss weights
loss_weights: Dict[str, float] = field(
default_factory=lambda: {
@@ -269,7 +290,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
def __post_init__(self):
super().__init__()
self.view_metrics = ViewMetrics()
if self.view_pooler_enabled:
if self.image_feature_extractor_class_type is None:
@@ -424,15 +444,31 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
for func in self._implicit_functions:
func.unbind_args()
preds = self._get_view_metrics(
raymarched=rendered,
xys=ray_bundle.xys,
image_rgb=None if image_rgb is None else image_rgb[:n_targets],
depth_map=None if depth_map is None else depth_map[:n_targets],
fg_probability=None
if fg_probability is None
else fg_probability[:n_targets],
mask_crop=None if mask_crop is None else mask_crop[:n_targets],
# A dict to store losses as well as rendering results.
preds: Dict[str, Any] = {}
def safe_slice_targets(
tensor: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
return None if tensor is None else tensor[:n_targets]
preds.update(
self.view_metrics(
results=preds,
raymarched=rendered,
xys=ray_bundle.xys,
image_rgb=safe_slice_targets(image_rgb),
depth_map=safe_slice_targets(depth_map),
fg_probability=safe_slice_targets(fg_probability),
mask_crop=safe_slice_targets(mask_crop),
)
)
preds.update(
self.regularization_metrics(
results=preds,
model=self,
)
)
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
@@ -462,11 +498,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
else:
raise AssertionError("Unreachable state")
# calc the AD penalty, returns None if autodecoder is not active
ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
if ad_penalty is not None:
preds["loss_autodecoder_norm"] = ad_penalty
# (7) Compute losses
# finally get the optimization objective using self.loss_weights
objective = self._get_objective(preds)
@@ -744,45 +775,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
return image_rgb, fg_mask, depth_map
def _get_view_metrics(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
metrics = self.view_metrics(
image_sampling_grid=xys,
images_pred=raymarched.features,
images=image_rgb,
depths_pred=raymarched.depths,
depths=depth_map,
masks_pred=raymarched.masks,
masks=fg_probability,
masks_crop=mask_crop,
keys_prefix=keys_prefix,
**raymarched.aux,
)
if raymarched.prev_stage:
metrics.update(
self._get_view_metrics(
raymarched.prev_stage,
xys,
image_rgb,
depth_map,
fg_probability,
mask_crop,
keys_prefix=(keys_prefix + "prev_stage_"),
)
)
return metrics
@torch.no_grad()
def _rasterize_mc_samples(
self,

View File

@@ -6,64 +6,180 @@
import warnings
from typing import Dict, Optional
from typing import Any, Dict, Optional
import torch
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer import utils as rend_utils
from .renderer.base import RendererOutput
class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
"""
Replaceable abstract base for regularization metrics.
The `forward()` method produces regularization metrics and (unlike ViewMetrics) can
depend on the model's parameters.
"""
def __post_init__(self) -> None:
super().__init__()
def forward(
self, model: Any, keys_prefix: str = "loss_", **kwargs
) -> Dict[str, Any]:
"""
Calculates various regularization terms useful for supervising differentiable
rendering pipelines.
Args:
model: A model instance. Useful, for example, to implement
weights-based regularization.
keys_prefix: A common prefix for all keys in the output dictionary
containing all regularization metrics.
Returns:
A dictionary with the resulting regularization metrics. The items
will have form `{metric_name_i: metric_value_i}` keyed by the
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
"""
raise NotImplementedError
class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
"""
Replaceable abstract base for model metrics.
The `forward()` method produces losses and other metrics.
"""
def __post_init__(self) -> None:
super().__init__()
class ViewMetrics(torch.nn.Module):
def forward(
self,
image_sampling_grid: torch.Tensor,
images: Optional[torch.Tensor] = None,
images_pred: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
depths_pred: Optional[torch.Tensor] = None,
masks: Optional[torch.Tensor] = None,
masks_pred: Optional[torch.Tensor] = None,
masks_crop: Optional[torch.Tensor] = None,
grad_theta: Optional[torch.Tensor] = None,
density_grid: Optional[torch.Tensor] = None,
raymarched: RendererOutput,
xys: torch.Tensor,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
mask_renders_by_pred: bool = False,
**kwargs,
) -> Dict[str, torch.Tensor]:
) -> Dict[str, Any]:
"""
Calculates various metrics and loss functions useful for supervising
differentiable rendering pipelines. Any additional parameters can be passed
in the `raymarched.aux` dictionary.
Args:
results: A dictionary with the resulting view metrics. The items
will have form `{metric_name_i: metric_value_i}` keyed by the
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
raymarched: Output of the renderer.
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
the predictions are defined. All ground truth inputs are sampled at
these locations in order to extract values that correspond to the
predictions.
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
values.
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
values.
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
foreground masks.
keys_prefix: A common prefix for all keys in the output dictionary
containing all view metrics.
Returns:
A dictionary with the resulting view metrics. The items
will have form `{metric_name_i: metric_value_i}` keyed by the
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
"""
raise NotImplementedError()
@registry.register
class RegularizationMetrics(RegularizationMetricsBase):
def forward(
self, model: Any, keys_prefix: str = "loss_", **kwargs
) -> Dict[str, Any]:
"""
Calculates the AD penalty, or returns an empty dict if the model's
autodecoder is inactive.
Args:
model: A model instance.
keys_prefix: A common prefix for all keys in the output dictionary
containing all regularization metrics.
Returns:
A dictionary with the resulting regularization metrics. The items
will have form `{metric_name_i: metric_value_i}` keyed by the
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
The calculated metric is:
autodecoder_norm: Autodecoder weight norm regularization term.
"""
metrics = {}
if getattr(model, "sequence_autodecoder", None) is not None:
ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
if ad_penalty is not None:
metrics["autodecoder_norm"] = ad_penalty
if keys_prefix is not None:
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
return metrics
@registry.register
class ViewMetrics(ViewMetricsBase):
def forward(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
**kwargs,
) -> Dict[str, Any]:
"""
Calculates various differentiable metrics useful for supervising
differentiable rendering pipelines.
Args:
image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
image locations at which the predictions are defined.
All ground truth inputs are sampled at these
locations in order to extract values that correspond
to the predictions.
images: A tensor of shape `(B, H, W, 3)` containing ground truth
rgb values.
images_pred: A tensor of shape `(B, ..., 3)` containing predicted
rgb values.
depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
depth values.
depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
depth values.
masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
foreground masks.
masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
foreground masks.
grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
of a gradient of a signed distance function w.r.t.
results: A dict to store the results in.
raymarched.features: Predicted rgb or feature values.
raymarched.depths: A tensor of shape `(B, ..., 1)` containing
predicted depth values.
raymarched.masks: A tensor of shape `(B, ..., 1)` containing
predicted foreground masks.
raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing an
evaluation of a gradient of a signed distance function w.r.t.
input 3D coordinates used to compute the eikonal loss.
density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
`Hg x Wg x Dg` voxel grid of density values.
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
containing a `Hg x Wg x Dg` voxel grid of density values.
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
the predictions are defined. All ground truth inputs are sampled at
these locations in order to extract values that correspond to the
predictions.
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
values.
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
values.
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
foreground masks.
keys_prefix: A common prefix for all keys in the output dictionary
containing all metrics.
mask_renders_by_pred: If `True`, masks rendered images by the predicted
`masks_pred` prior to computing all rgb metrics.
containing all view metrics.
Returns:
metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
A dictionary `{metric_name_i: metric_value_i}` keyed by the
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
@@ -91,109 +207,142 @@ class ViewMetrics(torch.nn.Module):
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
predicted depth values.
"""
metrics = self._calculate_stage(
raymarched,
xys,
image_rgb,
depth_map,
fg_probability,
mask_crop,
keys_prefix,
)
if raymarched.prev_stage:
metrics.update(
self(
raymarched.prev_stage,
xys,
image_rgb,
depth_map,
fg_probability,
mask_crop,
keys_prefix=(keys_prefix + "prev_stage_"),
)
)
return metrics
def _calculate_stage(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
**kwargs,
) -> Dict[str, Any]:
"""
Calculate metrics for the current stage.
"""
# TODO: extract functions
# reshape from B x ... x DIM to B x DIM x -1 x 1
images_pred, masks_pred, depths_pred = [
_reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
image_rgb_pred, fg_probability_pred, depth_map_pred = [
_reshape_nongrid_var(x)
for x in [raymarched.features, raymarched.masks, raymarched.depths]
]
# reshape the sampling grid as well
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
# now that we use rend_utils.ndc_grid_sample
image_sampling_grid = image_sampling_grid.reshape(
image_sampling_grid.shape[0], -1, 1, 2
)
xys = xys.reshape(xys.shape[0], -1, 1, 2)
# closure with the given image_sampling_grid
# closure with the given xys
def sample(tensor, mode):
if tensor is None:
return tensor
return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
# eval all results in this size
images = sample(images, mode="bilinear")
depths = sample(depths, mode="nearest")
masks = sample(masks, mode="nearest")
masks_crop = sample(masks_crop, mode="nearest")
if masks_crop is None and images_pred is not None:
masks_crop = torch.ones_like(images_pred[:, :1])
if masks_crop is None and depths_pred is not None:
masks_crop = torch.ones_like(depths_pred[:, :1])
image_rgb = sample(image_rgb, mode="bilinear")
depth_map = sample(depth_map, mode="nearest")
fg_probability = sample(fg_probability, mode="nearest")
mask_crop = sample(mask_crop, mode="nearest")
if mask_crop is None and image_rgb_pred is not None:
mask_crop = torch.ones_like(image_rgb_pred[:, :1])
if mask_crop is None and depth_map_pred is not None:
mask_crop = torch.ones_like(depth_map_pred[:, :1])
preds = {}
if images is not None and images_pred is not None:
# TODO: mask_renders_by_pred is always false; simplify
preds.update(
metrics = {}
if image_rgb is not None and image_rgb_pred is not None:
metrics.update(
_rgb_metrics(
images,
images_pred,
masks,
masks_pred,
masks_crop,
mask_renders_by_pred,
image_rgb,
image_rgb_pred,
fg_probability,
fg_probability_pred,
mask_crop,
)
)
if masks_pred is not None:
preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
if masks is not None and masks_pred is not None:
preds["mask_neg_iou"] = utils.neg_iou_loss(
masks_pred, masks, mask=masks_crop
if fg_probability_pred is not None:
metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
if fg_probability is not None and fg_probability_pred is not None:
metrics["mask_neg_iou"] = utils.neg_iou_loss(
fg_probability_pred, fg_probability, mask=mask_crop
)
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred, fg_probability, mask=mask_crop
)
preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
if depths is not None and depths_pred is not None:
assert masks_crop is not None
if depth_map is not None and depth_map_pred is not None:
assert mask_crop is not None
_, abs_ = utils.eval_depth(
depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
)
preds["depth_abs"] = abs_.mean()
metrics["depth_abs"] = abs_.mean()
if masks is not None:
mask = masks * masks_crop
if fg_probability is not None:
mask = fg_probability * mask_crop
_, abs_ = utils.eval_depth(
depths_pred, depths, get_best_scale=True, mask=mask, crop=0
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
)
preds["depth_abs_fg"] = abs_.mean()
metrics["depth_abs_fg"] = abs_.mean()
# regularizers
grad_theta = raymarched.aux.get("grad_theta")
if grad_theta is not None:
preds["eikonal"] = _get_eikonal_loss(grad_theta)
metrics["eikonal"] = _get_eikonal_loss(grad_theta)
density_grid = raymarched.aux.get("density_grid")
if density_grid is not None:
preds["density_tv"] = _get_grid_tv_loss(density_grid)
metrics["density_tv"] = _get_grid_tv_loss(density_grid)
if depths_pred is not None:
preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
if depth_map_pred is not None:
metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)
if keys_prefix is not None:
preds = {(keys_prefix + k): v for k, v in preds.items()}
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
return preds
return metrics
def _rgb_metrics(
images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
):
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
assert masks_crop is not None
if mask_renders_by_pred:
images = images[..., masks_pred.reshape(-1), :]
masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
crop_mass = masks_crop.sum().clamp(1.0)
preds = {
results = {
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
}
if masks is not None:
masks = masks_crop * masks
preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
return preds
results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
return results
def _get_eikonal_loss(grad_theta):

View File

@@ -20,6 +20,8 @@ renderer_class_type: LSTMRenderer
image_feature_extractor_class_type: ResNetFeatureExtractor
view_pooler_enabled: true
implicit_function_class_type: IdrFeatureField
view_metrics_class_type: ViewMetrics
regularization_metrics_class_type: RegularizationMetrics
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
@@ -122,3 +124,5 @@ implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 1729
pooled_feature_dim: 0
encoding_dim: 0
view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {}
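
For comparison, a hypothetical custom implementation registered under another name (such as the `MyViewMetrics` sketch in the commit summary above) would be selected here by swapping these keys, following the `<component>_<ClassType>_args` convention these configs already use:

    view_metrics_class_type: MyViewMetrics
    view_metrics_MyViewMetrics_args: {}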