Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 11:52:50 +08:00)
Summary: GenericModel crashes whenever the `aux` field of any Renderer is populated. This is because `rendered.aux` is unpacked into ViewMetrics.forward, whose signature does not contain `**kwargs`; the contents of `aux` are therefore unknown to forward's signature, resulting in a crash.

Reviewed By: bottler

Differential Revision: D36166118

fbshipit-source-id: 906a067ea02a3648a69667422466451bc219ebf6
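In plain Python terms, the crash is ordinary keyword-argument unpacking: once the renderer's output dictionary picks up extra keys from `aux`, splatting it into a `forward` that lists its parameters explicitly raises a `TypeError`, while adding `**kwargs` (as done in the file below) lets the unknown keys pass through harmlessly. A minimal, self-contained sketch of the pattern; the class and key names here are illustrative and not taken from the library:

import torch


class MetricsWithoutKwargs(torch.nn.Module):
    # old-style signature: only explicitly listed keys are accepted
    def forward(self, images_pred=None, depths_pred=None):
        return {}


class MetricsWithKwargs(torch.nn.Module):
    # fixed signature: extra keys (e.g. the contents of `aux`) are absorbed
    def forward(self, images_pred=None, depths_pred=None, **kwargs):
        return {}


# "aux_signal" is a made-up stand-in for whatever a renderer puts into `aux`
rendered = {"images_pred": None, "depths_pred": None, "aux_signal": 0}

MetricsWithKwargs()(**rendered)      # fine: "aux_signal" lands in **kwargs
# MetricsWithoutKwargs()(**rendered) # TypeError: unexpected keyword argument 'aux_signal'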
231 lines | 9.6 KiB | Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import warnings
from typing import Dict, Optional

import torch
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.renderer import utils as rend_utils


class ViewMetrics(torch.nn.Module):
    def forward(
        self,
        image_sampling_grid: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        images_pred: Optional[torch.Tensor] = None,
        depths: Optional[torch.Tensor] = None,
        depths_pred: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
        masks_pred: Optional[torch.Tensor] = None,
        masks_crop: Optional[torch.Tensor] = None,
        grad_theta: Optional[torch.Tensor] = None,
        density_grid: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        mask_renders_by_pred: bool = False,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Calculates various differentiable metrics useful for supervising
        differentiable rendering pipelines.

        Args:
            image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
                image locations at which the predictions are defined.
                All ground truth inputs are sampled at these
                locations in order to extract values that correspond
                to the predictions.
            images: A tensor of shape `(B, H, W, 3)` containing ground truth
                rgb values.
            images_pred: A tensor of shape `(B, ..., 3)` containing predicted
                rgb values.
            depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
                depth values.
            depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
                depth values.
            masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
                foreground masks.
            masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
                foreground masks.
            grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
                of a gradient of a signed distance function w.r.t.
                input 3D coordinates used to compute the eikonal loss.
            density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
                `Hg x Wg x Dg` voxel grid of density values.
            keys_prefix: A common prefix for all keys in the output dictionary
                containing all metrics.
            mask_renders_by_pred: If `True`, masks rendered images by the predicted
                `masks_pred` prior to computing all rgb metrics.

        Returns:
            metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
                names of the output metrics `metric_name_i` with their corresponding
                values `metric_value_i` represented as 0-dimensional float tensors.

                The calculated metrics are:
                    rgb_huber: A robust huber loss between `image_pred` and `image`.
                    rgb_mse: Mean squared error between `image_pred` and `image`.
                    rgb_psnr: Peak signal-to-noise ratio between `image_pred` and `image`.
                    rgb_psnr_fg: Peak signal-to-noise ratio between the foreground
                        region of `image_pred` and `image` as defined by `mask`.
                    rgb_mse_fg: Mean squared error between the foreground
                        region of `image_pred` and `image` as defined by `mask`.
                    mask_neg_iou: (1 - intersection-over-union) between `mask_pred`
                        and `mask`.
                    mask_bce: Binary cross entropy between `mask_pred` and `mask`.
                    mask_beta_prior: A loss enforcing strictly binary values
                        of `mask_pred`: `log(mask_pred) + log(1-mask_pred)`
                    depth_abs: Mean per-pixel L1 distance between
                        `depth_pred` and `depth`.
                    depth_abs_fg: Mean per-pixel L1 distance between the foreground
                        region of `depth_pred` and `depth` as defined by `mask`.
                    eikonal: Eikonal regularizer `(||grad_theta|| - 1)**2`.
                    density_tv: The Total Variation regularizer of density
                        values in `density_grid` (sum of L1 distances of values
                        of all 4-neighbouring cells).
                    depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
                        predicted depth values.
        """

        # TODO: extract functions

        # reshape from B x ... x DIM to B x DIM x -1 x 1
        images_pred, masks_pred, depths_pred = [
            _reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
        ]
        # reshape the sampling grid as well
        # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
        # now that we use rend_utils.ndc_grid_sample
        image_sampling_grid = image_sampling_grid.reshape(
            image_sampling_grid.shape[0], -1, 1, 2
        )

        # closure with the given image_sampling_grid
        def sample(tensor, mode):
            if tensor is None:
                return tensor
            return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)

        # eval all results in this size
        images = sample(images, mode="bilinear")
        depths = sample(depths, mode="nearest")
        masks = sample(masks, mode="nearest")
        masks_crop = sample(masks_crop, mode="nearest")
        if masks_crop is None and images_pred is not None:
            masks_crop = torch.ones_like(images_pred[:, :1])
        if masks_crop is None and depths_pred is not None:
            masks_crop = torch.ones_like(depths_pred[:, :1])

        preds = {}
        if images is not None and images_pred is not None:
            # TODO: mask_renders_by_pred is always false; simplify
            preds.update(
                _rgb_metrics(
                    images,
                    images_pred,
                    masks,
                    masks_pred,
                    masks_crop,
                    mask_renders_by_pred,
                )
            )

        if masks_pred is not None:
            preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
        if masks is not None and masks_pred is not None:
            preds["mask_neg_iou"] = utils.neg_iou_loss(
                masks_pred, masks, mask=masks_crop
            )
            preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)

        if depths is not None and depths_pred is not None:
            assert masks_crop is not None
            _, abs_ = utils.eval_depth(
                depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
            )
            preds["depth_abs"] = abs_.mean()

            if masks is not None:
                mask = masks * masks_crop
                _, abs_ = utils.eval_depth(
                    depths_pred, depths, get_best_scale=True, mask=mask, crop=0
                )
                preds["depth_abs_fg"] = abs_.mean()

        # regularizers
        if grad_theta is not None:
            preds["eikonal"] = _get_eikonal_loss(grad_theta)

        if density_grid is not None:
            preds["density_tv"] = _get_grid_tv_loss(density_grid)

        if depths_pred is not None:
            preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)

        if keys_prefix is not None:
            preds = {(keys_prefix + k): v for k, v in preds.items()}

        return preds


def _rgb_metrics(
    images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
):
    assert masks_crop is not None
    if mask_renders_by_pred:
        images = images[..., masks_pred.reshape(-1), :]
        masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
        masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
    rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
    crop_mass = masks_crop.sum().clamp(1.0)
    preds = {
        "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
        "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
        "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
    }
    if masks is not None:
        masks = masks_crop * masks
        preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
        preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
    return preds


def _get_eikonal_loss(grad_theta):
    return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()


def _get_grid_tv_loss(grid, log_domain: bool = True, eps: float = 1e-5):
    if log_domain:
        if (grid <= -eps).any():
            warnings.warn("Grid has negative values; this will produce NaN loss")
        grid = torch.log(grid + eps)

    # this is an isotropic version, note that it ignores last rows/cols
    return torch.mean(
        utils.safe_sqrt(
            (grid[..., :-1, :-1, 1:] - grid[..., :-1, :-1, :-1]) ** 2
            + (grid[..., :-1, 1:, :-1] - grid[..., :-1, :-1, :-1]) ** 2
            + (grid[..., 1:, :-1, :-1] - grid[..., :-1, :-1, :-1]) ** 2,
            eps=1e-5,
        )
    )


def _get_depth_neg_penalty_loss(depth):
    neg_penalty = depth.clamp(min=None, max=0.0) ** 2
    return torch.mean(neg_penalty)


def _reshape_nongrid_var(x):
    if x is None:
        return None

    ba, *_, dim = x.shape
    return x.reshape(ba, -1, 1, dim).permute(0, 3, 1, 2).contiguous()
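To make the expected tensor layouts concrete, here is a hypothetical smoke test of `ViewMetrics` on random data. It is not part of the file above, and the shapes are assumptions: the ground-truth tensors are given channel-first, matching the `(B, dim, H, W)` layout that `torch.nn.functional.grid_sample` (used underneath `rend_utils.ndc_grid_sample`) consumes, while the per-point predictions use the `(B, ..., dim)` layout reshaped by `_reshape_nongrid_var`.

import torch

# ViewMetrics as defined in the module above; the import path may differ per install.
B, H, W, N = 2, 8, 8, 16  # batch size, ground-truth resolution, number of sampled points (assumed)

metrics = ViewMetrics()(
    # NDC locations of the N predictions, reshaped internally to (B, N, 1, 2)
    image_sampling_grid=torch.rand(B, N, 2) * 2 - 1,
    # ground truth, channel-first so it can be sampled by ndc_grid_sample
    images=torch.rand(B, 3, H, W),
    depths=torch.rand(B, 1, H, W) + 0.1,
    masks=(torch.rand(B, 1, H, W) > 0.5).float(),
    # per-point predictions, channel-last; reshaped internally to (B, dim, N, 1)
    images_pred=torch.rand(B, N, 3),
    depths_pred=torch.rand(B, N, 1) + 0.1,
    masks_pred=torch.rand(B, N, 1),
)

print(sorted(metrics))  # e.g. loss_depth_abs, loss_mask_bce, loss_rgb_mse, ...

Any keyword not listed in the signature (for example, extra keys coming from a renderer's `aux`) is simply absorbed by `**kwargs`, which is exactly the behaviour the commit above restores.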