mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Summary: Rasterize MC was not adapted to heterogeneous bundles. There are some caveats though: 1) on CO3D, we get up to 18 points per image, which is too few for a reasonable visualisation (see below); 2) rasterising for a batch of 100 is slow. I also moved the unpacking code close to the bundle to be able to reuse it. {F789678778} Reviewed By: bottler, davnov134 Differential Revision: D41008600 fbshipit-source-id: 9f10f1f9f9a174cf8c534b9b9859587d69832b71
407 lines
16 KiB
Python
407 lines
16 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import warnings
|
|
from typing import Any, Dict, Optional
|
|
|
|
import torch
|
|
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
|
from pytorch3d.implicitron.tools import metric_utils as utils
|
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
|
from pytorch3d.ops import padded_to_packed
|
|
from pytorch3d.renderer import utils as rend_utils
|
|
|
|
from .renderer.base import RendererOutput
|
|
|
|
|
|
class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
    """
    Pluggable abstract interface for regularization metrics.

    Subclasses override `forward()` to produce regularization terms. In
    contrast to ViewMetrics, these terms are allowed to depend on the
    parameters of the model.
    """

    def __post_init__(self) -> None:
        # Run the parent initialisation chain; no arguments are forwarded.
        super().__init__()

    def forward(
        self, model: Any, keys_prefix: str = "loss_", **kwargs
    ) -> Dict[str, Any]:
        """
        Computes regularization terms that are useful when supervising
        differentiable rendering pipelines.

        Args:
            model: The model instance to regularize. Having access to it
                makes it possible to implement, e.g., weights-based
                regularization.
            keys_prefix: A prefix shared by all the keys of the returned
                dictionary.

        Returns:
            A dictionary `{metric_name_i: metric_value_i}` which maps the name
            of every regularization metric to its value, represented as
            a 0-dimensional float tensor.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
    """
    Pluggable abstract interface for per-view model metrics.

    Subclasses override `forward()` to produce losses and other metrics.
    """

    def __post_init__(self) -> None:
        # Run the parent initialisation chain; no arguments are forwarded.
        super().__init__()

    def forward(
        self,
        raymarched: RendererOutput,
        ray_bundle: ImplicitronRayBundle,
        image_rgb: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Computes metrics and loss functions that are useful when supervising
        differentiable rendering pipelines. Additional inputs can be handed
        over through the `raymarched.aux` dictionary.

        Args:
            raymarched: Output of the renderer.
            ray_bundle: The ImplicitronRayBundle object that was used to
                produce `raymarched`.
            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth
                rgb values.
            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground
                truth depth values.
            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing
                ground truth foreground masks.
            mask_crop: An optional crop mask tensor; how it is used is left
                to the implementation.
            keys_prefix: A prefix shared by all the keys of the returned
                dictionary.

        Returns:
            A dictionary `{metric_name_i: metric_value_i}` which maps the name
            of every view metric to its value, represented as
            a 0-dimensional float tensor.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
|
|
|
|
|
|
@registry.register
class RegularizationMetrics(RegularizationMetricsBase):
    def forward(
        self, model: Any, keys_prefix: str = "loss_", **kwargs
    ) -> Dict[str, Any]:
        """
        Computes the autodecoder penalty. The result is an empty dict when
        the model has no `sequence_autodecoder` (or it is None), or when the
        autodecoder reports no penalty.

        Args:
            model: A model instance; only its optional `sequence_autodecoder`
                attribute is read.
            keys_prefix: A prefix prepended to every key of the returned
                dictionary; pass None to leave the keys unprefixed.

        Returns:
            A dictionary `{metric_name_i: metric_value_i}` with the
            regularization metrics.

            The calculated metric is:
                autodecoder_norm: The value returned by
                    `model.sequence_autodecoder.calculate_squared_encoding_norm()`.
        """
        autodecoder = getattr(model, "sequence_autodecoder", None)
        penalty = (
            None
            if autodecoder is None
            else autodecoder.calculate_squared_encoding_norm()
        )
        if penalty is None:
            # Nothing to regularize.
            return {}

        name = "autodecoder_norm"
        if keys_prefix is not None:
            name = keys_prefix + name
        return {name: penalty}
|
|
|
|
|
|
@registry.register
class ViewMetrics(ViewMetricsBase):
    def forward(
        self,
        raymarched: RendererOutput,
        ray_bundle: ImplicitronRayBundle,
        image_rgb: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Calculates various differentiable metrics useful for supervising
        differentiable rendering pipelines. The metrics of the current stage
        are computed first; if `raymarched.prev_stage` is set, the metrics of
        the previous stages are added recursively under the extended prefix
        `keys_prefix + "prev_stage_"`.

        Args:
            raymarched.features: Predicted rgb or feature values.
            raymarched.depths: A tensor of shape `(B, ..., 1)` containing
                predicted depth values.
            raymarched.masks: A tensor of shape `(B, ..., 1)` containing
                predicted foreground masks.
            raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)`
                containing an evaluation of a gradient of a signed distance
                function w.r.t. input 3D coordinates used to compute the
                eikonal loss.
            raymarched.aux["density_grid"]: A tensor of shape
                `(B, Hg, Wg, Dg, 1)` containing a `Hg x Wg x Dg` voxel grid
                of density values.
            ray_bundle: ImplicitronRayBundle object which was used to produce
                the raymarched object.
            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth
                rgb values.
                NOTE(review): the sampled ground truth is compared with the
                predictions using dim 1 as the channel dimension, which
                suggests a channels-first layout `(B, 3, H, W)` for this and
                the following image-like inputs -- confirm against callers.
            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground
                truth depth values.
            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing
                ground truth foreground masks.
            mask_crop: An optional mask which weights the rgb, mask and depth
                metrics. When None, a mask of ones shaped after the predicted
                image (or the predicted depth) is used instead.
            keys_prefix: A common prefix for all keys in the output dictionary
                containing all view metrics.

        Returns:
            A dictionary `{metric_name_i: metric_value_i}` keyed by the
            names of the output metrics `metric_name_i` with their
            corresponding values `metric_value_i` represented as 0-dimensional
            float tensors.

            The calculated metrics are:
                rgb_huber: A robust huber loss between `image_pred` and `image`.
                rgb_mse: Mean squared error between `image_pred` and `image`.
                rgb_psnr: Peak signal-to-noise ratio between `image_pred`
                    and `image`.
                rgb_psnr_fg: Peak signal-to-noise ratio between the foreground
                    region of `image_pred` and `image` as defined by `mask`.
                rgb_mse_fg: Mean squared error between the foreground
                    region of `image_pred` and `image` as defined by `mask`.
                mask_neg_iou: (1 - intersection-over-union) between `mask_pred`
                    and `mask`.
                mask_bce: Binary cross entropy between `mask_pred` and `mask`.
                mask_beta_prior: A loss enforcing strictly binary values
                    of `mask_pred`: `log(mask_pred) + log(1-mask_pred)`
                depth_abs: Mean per-pixel L1 distance between
                    `depth_pred` and `depth`.
                depth_abs_fg: Mean per-pixel L1 distance between the foreground
                    region of `depth_pred` and `depth` as defined by `mask`.
                eikonal: Eikonal regularizer `(||grad_theta|| - 1)**2`.
                density_tv: The Total Variation regularizer of the density
                    values in `density_grid`, as computed by
                    `_get_grid_tv_loss`.
                depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
                    predicted depth values.
        """
        metrics = self._calculate_stage(
            raymarched,
            ray_bundle,
            image_rgb,
            depth_map,
            fg_probability,
            mask_crop,
            keys_prefix,
        )

        if raymarched.prev_stage:
            # Recurse into the previous rendering stage(s); their metrics are
            # distinguished by the extended key prefix.
            metrics.update(
                self(
                    raymarched.prev_stage,
                    ray_bundle,
                    image_rgb,
                    depth_map,
                    fg_probability,
                    mask_crop,
                    keys_prefix=(keys_prefix + "prev_stage_"),
                )
            )

        return metrics

    def _calculate_stage(
        self,
        raymarched: RendererOutput,
        ray_bundle: ImplicitronRayBundle,
        image_rgb: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Calculate metrics for the current stage.

        The ground truth tensors are sampled at the ray locations of
        `ray_bundle` so that they can be compared with the predictions stored
        in `raymarched`. The arguments are the same as in `forward()`.

        Returns:
            A dictionary mapping `keys_prefix + metric_name` to the metric
            value (the names are left unprefixed if `keys_prefix` is None).
        """
        # TODO: extract functions

        # reshape from B x ... x DIM to B x DIM x -1 x 1
        image_rgb_pred, fg_probability_pred, depth_map_pred = [
            _reshape_nongrid_var(x)
            for x in [raymarched.features, raymarched.masks, raymarched.depths]
        ]
        xys = ray_bundle.xys

        # If ray_bundle is packed then we can sample images in padded state to
        # lower memory requirements. Instead of having one image for every
        # element in ray_bundle we can then have one image per unique sampled
        # camera.
        if ray_bundle.is_packed():
            xys, first_idxs, num_inputs = ray_bundle.get_padded_xys()

        # reshape the sampling grid as well
        # TODO: we can get rid of the singular dimension here and in
        # _reshape_nongrid_var now that we use rend_utils.ndc_grid_sample
        xys = xys.reshape(xys.shape[0], -1, 1, 2)

        # closure with the given xys
        def sample_full(tensor, mode):
            if tensor is None:
                return tensor
            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)

        def sample_packed(tensor, mode):
            # Only used when the bundle is packed, hence `first_idxs` and
            # `num_inputs` are always bound when this closure is called.
            if tensor is None:
                return tensor

            # Select the images that correspond to the sampled cameras.
            # This must happen exactly once: indexing the already selected
            # images with `camera_ids` a second time picks wrong images (or
            # goes out of bounds) unless `camera_ids` happens to be the
            # identity permutation over all the images.
            tensor = tensor[ray_bundle.camera_ids]
            result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
            return padded_to_packed(result, first_idxs, num_inputs, max_size_dim=2)[
                :, :, None
            ]  # the result is [n_rays_total_training, 3, 1, 1]

        sample = sample_packed if ray_bundle.is_packed() else sample_full

        # eval all results in this size
        image_rgb = sample(image_rgb, mode="bilinear")
        depth_map = sample(depth_map, mode="nearest")
        fg_probability = sample(fg_probability, mode="nearest")
        mask_crop = sample(mask_crop, mode="nearest")
        # Without a crop mask every sampled location is given unit weight.
        if mask_crop is None and image_rgb_pred is not None:
            mask_crop = torch.ones_like(image_rgb_pred[:, :1])
        if mask_crop is None and depth_map_pred is not None:
            mask_crop = torch.ones_like(depth_map_pred[:, :1])

        metrics = {}
        if image_rgb is not None and image_rgb_pred is not None:
            metrics.update(
                _rgb_metrics(
                    image_rgb,
                    image_rgb_pred,
                    fg_probability,
                    fg_probability_pred,
                    mask_crop,
                )
            )

        if fg_probability_pred is not None:
            metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
        if fg_probability is not None and fg_probability_pred is not None:
            metrics["mask_neg_iou"] = utils.neg_iou_loss(
                fg_probability_pred, fg_probability, mask=mask_crop
            )
            metrics["mask_bce"] = utils.calc_bce(
                fg_probability_pred, fg_probability, mask=mask_crop
            )

        if depth_map is not None and depth_map_pred is not None:
            assert mask_crop is not None
            _, abs_ = utils.eval_depth(
                depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
            )
            metrics["depth_abs"] = abs_.mean()

            if fg_probability is not None:
                # Restrict the depth error to the foreground region.
                mask = fg_probability * mask_crop
                _, abs_ = utils.eval_depth(
                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
                )
                metrics["depth_abs_fg"] = abs_.mean()

        # regularizers
        grad_theta = raymarched.aux.get("grad_theta")
        if grad_theta is not None:
            metrics["eikonal"] = _get_eikonal_loss(grad_theta)

        density_grid = raymarched.aux.get("density_grid")
        if density_grid is not None:
            metrics["density_tv"] = _get_grid_tv_loss(density_grid)

        if depth_map_pred is not None:
            metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)

        if keys_prefix is not None:
            metrics = {(keys_prefix + k): v for k, v in metrics.items()}

        return metrics
|
|
|
|
|
|
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
|
|
assert masks_crop is not None
|
|
if images.shape[1] != images_pred.shape[1]:
|
|
raise ValueError(
|
|
f"Network output's RGB images had {images_pred.shape[1]} "
|
|
f"channels. {images.shape[1]} expected."
|
|
)
|
|
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
|
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
|
crop_mass = masks_crop.sum().clamp(1.0)
|
|
results = {
|
|
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
|
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
|
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
|
}
|
|
if masks is not None:
|
|
masks = masks_crop * masks
|
|
results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
|
results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
|
return results
|
|
|
|
|
|
def _get_eikonal_loss(grad_theta):
|
|
return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
|
|
|
|
|
|
def _get_grid_tv_loss(grid, log_domain: bool = True, eps: float = 1e-5):
    """
    Isotropic total-variation loss of a voxel grid.

    Forward differences are taken along each of the last three dimensions of
    `grid`; the loss is the mean over all cells of the (safe) square root of
    the sum of the three squared differences. The last row/column/slice of
    each of these dimensions is ignored.

    NOTE(review): as the differences span the last three dims, a trailing
    singleton (channel) dimension would make the sliced tensors empty --
    confirm the layout of the grids passed in.

    Args:
        grid: The tensor holding the grid values.
        log_domain: If True, the loss is evaluated on `log(grid + eps)`.
        eps: Offset added before taking the logarithm; a warning is issued
            if any value of `grid` is less than or equal to `-eps`.

    Returns:
        The mean total variation.
    """
    if log_domain:
        has_bad_values = (grid <= -eps).any()
        if has_bad_values:
            warnings.warn("Grid has negative values; this will produce NaN loss")
        grid = torch.log(grid + eps)

    # this is an isotropic version, note that it ignores last rows/cols
    origin = grid[..., :-1, :-1, :-1]
    diff_last = grid[..., :-1, :-1, 1:] - origin
    diff_middle = grid[..., :-1, 1:, :-1] - origin
    diff_first = grid[..., 1:, :-1, :-1] - origin
    squared_norm = diff_last**2 + diff_middle**2 + diff_first**2
    return utils.safe_sqrt(squared_norm, eps=1e-5).mean()
|
|
|
|
|
|
def _get_depth_neg_penalty_loss(depth):
|
|
neg_penalty = depth.clamp(min=None, max=0.0) ** 2
|
|
return torch.mean(neg_penalty)
|
|
|
|
|
|
def _reshape_nongrid_var(x):
|
|
if x is None:
|
|
return None
|
|
|
|
ba, *_, dim = x.shape
|
|
return x.reshape(ba, -1, 1, dim).permute(0, 3, 1, 2).contiguous()
|