diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml
index eda38f3a..1bdc0ddf 100644
--- a/projects/implicitron_trainer/tests/experiment.yaml
+++ b/projects/implicitron_trainer/tests/experiment.yaml
@@ -21,6 +21,8 @@ generic_model_args:
   image_feature_extractor_class_type: null
   view_pooler_enabled: false
   implicit_function_class_type: NeuralRadianceFieldImplicitFunction
+  view_metrics_class_type: ViewMetrics
+  regularization_metrics_class_type: RegularizationMetrics
   loss_weights:
     loss_rgb_mse: 1.0
     loss_prev_stage_rgb_mse: 1.0
@@ -268,6 +270,8 @@ generic_model_args:
       in_features: 256
       out_features: 3
       ray_dir_in_camera_coords: false
+  view_metrics_ViewMetrics_args: {}
+  regularization_metrics_RegularizationMetrics_args: {}
 solver_args:
   breed: adam
   weight_decay: 0.0
diff --git a/pytorch3d/implicitron/models/base_model.py b/pytorch3d/implicitron/models/base_model.py
index 1bcb577c..ffd7d19f 100644
--- a/pytorch3d/implicitron/models/base_model.py
+++ b/pytorch3d/implicitron/models/base_model.py
@@ -38,7 +38,8 @@ class ImplicitronRender:
 
 
 class ImplicitronModelBase(ReplaceableBase):
-    """Replaceable abstract base for all image generation / rendering models.
+    """
+    Replaceable abstract base for all image generation / rendering models.
 
     `forward()` method produces a render with a depth map.
     """
diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py
index 68c30e4d..e7ee80df 100644
--- a/pytorch3d/implicitron/models/generic_model.py
+++ b/pytorch3d/implicitron/models/generic_model.py
@@ -16,6 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import tqdm
+from pytorch3d.implicitron.models.metrics import (  # noqa
+    RegularizationMetrics,
+    RegularizationMetricsBase,
+    ViewMetrics,
+    ViewMetricsBase,
+)
 from pytorch3d.implicitron.tools import image_utils, vis_utils
 from pytorch3d.implicitron.tools.config import (
     expand_args_fields,
@@ -42,7 +48,7 @@ from .implicit_function.scene_representation_networks import (  # noqa
     SRNHyperNetImplicitFunction,
     SRNImplicitFunction,
 )
-from .metrics import ViewMetrics
+
 from .renderer.base import (
     BaseRenderer,
     EvaluationMode,
@@ -184,6 +190,14 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
             is available in the global registry.
         implicit_function: An instance of ImplicitFunctionBase. The actual
             implicit functions are initialised to be in self._implicit_functions.
+        view_metrics: An instance of ViewMetricsBase used to compute loss terms which
+            are independent of the model's parameters.
+        view_metrics_class_type: The type of view metrics to use, must be available in
+            the global registry.
+        regularization_metrics: An instance of RegularizationMetricsBase used to compute
+            regularization terms which can depend on the model's parameters.
+        regularization_metrics_class_type: The type of regularization metrics to use,
+            must be available in the global registry.
         loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
             for `ViewMetrics` class for available loss functions.
         log_vars: A list of variable names which should be logged.
@@ -232,6 +246,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
     # The actual implicit functions live in self._implicit_functions
     implicit_function: ImplicitFunctionBase
 
+    # ---- metrics
+    view_metrics: ViewMetricsBase
+    view_metrics_class_type: str = "ViewMetrics"
+
+    regularization_metrics: RegularizationMetricsBase
+    regularization_metrics_class_type: str = "RegularizationMetrics"
+
     # ---- loss weights
     loss_weights: Dict[str, float] = field(
         default_factory=lambda: {
@@ -269,7 +290,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 
     def __post_init__(self):
         super().__init__()
-        self.view_metrics = ViewMetrics()
 
         if self.view_pooler_enabled:
             if self.image_feature_extractor_class_type is None:
@@ -424,15 +444,31 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         for func in self._implicit_functions:
             func.unbind_args()
 
-        preds = self._get_view_metrics(
-            raymarched=rendered,
-            xys=ray_bundle.xys,
-            image_rgb=None if image_rgb is None else image_rgb[:n_targets],
-            depth_map=None if depth_map is None else depth_map[:n_targets],
-            fg_probability=None
-            if fg_probability is None
-            else fg_probability[:n_targets],
-            mask_crop=None if mask_crop is None else mask_crop[:n_targets],
+        # A dict to store losses as well as rendering results.
+        preds: Dict[str, Any] = {}
+
+        def safe_slice_targets(
+            tensor: Optional[torch.Tensor],
+        ) -> Optional[torch.Tensor]:
+            return None if tensor is None else tensor[:n_targets]
+
+        preds.update(
+            self.view_metrics(
+                results=preds,
+                raymarched=rendered,
+                xys=ray_bundle.xys,
+                image_rgb=safe_slice_targets(image_rgb),
+                depth_map=safe_slice_targets(depth_map),
+                fg_probability=safe_slice_targets(fg_probability),
+                mask_crop=safe_slice_targets(mask_crop),
+            )
+        )
+
+        preds.update(
+            self.regularization_metrics(
+                results=preds,
+                model=self,
+            )
         )
 
         if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
@@ -462,11 +498,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
         else:
             raise AssertionError("Unreachable state")
 
-        # calc the AD penalty, returns None if autodecoder is not active
-        ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
-        if ad_penalty is not None:
-            preds["loss_autodecoder_norm"] = ad_penalty
-
         # (7) Compute losses
         # finally get the optimization objective using self.loss_weights
         objective = self._get_objective(preds)
@@ -744,45 +775,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 
         return image_rgb, fg_mask, depth_map
 
-    def _get_view_metrics(
-        self,
-        raymarched: RendererOutput,
-        xys: torch.Tensor,
-        image_rgb: Optional[torch.Tensor] = None,
-        depth_map: Optional[torch.Tensor] = None,
-        fg_probability: Optional[torch.Tensor] = None,
-        mask_crop: Optional[torch.Tensor] = None,
-        keys_prefix: str = "loss_",
-    ):
-        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
-        metrics = self.view_metrics(
-            image_sampling_grid=xys,
-            images_pred=raymarched.features,
-            images=image_rgb,
-            depths_pred=raymarched.depths,
-            depths=depth_map,
-            masks_pred=raymarched.masks,
-            masks=fg_probability,
-            masks_crop=mask_crop,
-            keys_prefix=keys_prefix,
-            **raymarched.aux,
-        )
-
-        if raymarched.prev_stage:
-            metrics.update(
-                self._get_view_metrics(
-                    raymarched.prev_stage,
-                    xys,
-                    image_rgb,
-                    depth_map,
-                    fg_probability,
-                    mask_crop,
-                    keys_prefix=(keys_prefix + "prev_stage_"),
-                )
-            )
-
-        return metrics
-
     @torch.no_grad()
     def _rasterize_mc_samples(
         self,
diff --git a/pytorch3d/implicitron/models/metrics.py b/pytorch3d/implicitron/models/metrics.py
index 527bc4b4..168c8bde 100644
--- a/pytorch3d/implicitron/models/metrics.py
+++ b/pytorch3d/implicitron/models/metrics.py
@@ -6,64 +6,177 @@
 
 import warnings
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 import torch
 from pytorch3d.implicitron.tools import metric_utils as utils
+from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.renderer import utils as rend_utils
 
+from .renderer.base import RendererOutput
+
+
+class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
+    """
+    Replaceable abstract base for regularization metrics.
+    `forward()` method produces regularization metrics and (unlike ViewMetrics) can
+    depend on the model's parameters.
+    """
+
+    def __post_init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, model: Any, keys_prefix: str = "loss_", **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Calculates various regularization terms useful for supervising differentiable
+        rendering pipelines.
+
+        Args:
+            model: A model instance. Useful, for example, to implement
+                weights-based regularization.
+            keys_prefix: A common prefix for all keys in the output dictionary
+                containing all regularization metrics.
+
+        Returns:
+            A dictionary with the resulting regularization metrics. The items
+            will have form `{metric_name_i: metric_value_i}` keyed by the
+            names of the output metrics `metric_name_i` with their corresponding
+            values `metric_value_i` represented as 0-dimensional float tensors.
+        """
+        raise NotImplementedError
+
+
+class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
+    """
+    Replaceable abstract base for model metrics.
+    `forward()` method produces losses and other metrics.
+    """
+
+    def __post_init__(self) -> None:
+        super().__init__()
 
-class ViewMetrics(torch.nn.Module):
     def forward(
         self,
-        image_sampling_grid: torch.Tensor,
-        images: Optional[torch.Tensor] = None,
-        images_pred: Optional[torch.Tensor] = None,
-        depths: Optional[torch.Tensor] = None,
-        depths_pred: Optional[torch.Tensor] = None,
-        masks: Optional[torch.Tensor] = None,
-        masks_pred: Optional[torch.Tensor] = None,
-        masks_crop: Optional[torch.Tensor] = None,
-        grad_theta: Optional[torch.Tensor] = None,
-        density_grid: Optional[torch.Tensor] = None,
+        raymarched: RendererOutput,
+        xys: torch.Tensor,
+        image_rgb: Optional[torch.Tensor] = None,
+        depth_map: Optional[torch.Tensor] = None,
+        fg_probability: Optional[torch.Tensor] = None,
+        mask_crop: Optional[torch.Tensor] = None,
         keys_prefix: str = "loss_",
-        mask_renders_by_pred: bool = False,
         **kwargs,
-    ) -> Dict[str, torch.Tensor]:
+    ) -> Dict[str, Any]:
+        """
+        Calculates various metrics and loss functions useful for supervising
+        differentiable rendering pipelines. Any additional parameters can be passed
+        in the `raymarched.aux` dictionary.
+
+        Args:
+            results: A dict to store the results in.
+            raymarched: Output of the renderer.
+            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
+                the predictions are defined. All ground truth inputs are sampled at
+                these locations in order to extract values that correspond to the
+                predictions.
+            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
+                values.
+            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
+                values.
+            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
+                foreground masks.
+            keys_prefix: A common prefix for all keys in the output dictionary
+                containing all view metrics.
+
+        Returns:
+            A dictionary with the resulting view metrics. The items
+            will have form `{metric_name_i: metric_value_i}` keyed by the
+            names of the output metrics `metric_name_i` with their corresponding
+            values `metric_value_i` represented as 0-dimensional float tensors.
+        """
+        raise NotImplementedError()
+
+
+@registry.register
+class RegularizationMetrics(RegularizationMetricsBase):
+    def forward(
+        self, model: Any, keys_prefix: str = "loss_", **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Calculates the AD penalty, or returns an empty dict if the model's
+        autodecoder is inactive.
+
+        Args:
+            model: A model instance.
+            keys_prefix: A common prefix for all keys in the output dictionary
+                containing all regularization metrics.
+
+        Returns:
+            A dictionary with the resulting regularization metrics. The items
+            will have form `{metric_name_i: metric_value_i}` keyed by the
+            names of the output metrics `metric_name_i` with their corresponding
+            values `metric_value_i` represented as 0-dimensional float tensors.
+
+            The calculated metric is:
+                autodecoder_norm: Autodecoder encoding norm regularization term.
+        """
+        metrics = {}
+        if getattr(model, "sequence_autodecoder", None) is not None:
+            ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
+            if ad_penalty is not None:
+                metrics["autodecoder_norm"] = ad_penalty
+
+        if keys_prefix is not None:
+            metrics = {(keys_prefix + k): v for k, v in metrics.items()}
+
+        return metrics
+
+
+@registry.register
+class ViewMetrics(ViewMetricsBase):
+    def forward(
+        self,
+        raymarched: RendererOutput,
+        xys: torch.Tensor,
+        image_rgb: Optional[torch.Tensor] = None,
+        depth_map: Optional[torch.Tensor] = None,
+        fg_probability: Optional[torch.Tensor] = None,
+        mask_crop: Optional[torch.Tensor] = None,
+        keys_prefix: str = "loss_",
+        **kwargs,
+    ) -> Dict[str, Any]:
         """
         Calculates various differentiable metrics useful for supervising
         differentiable rendering pipelines.
 
         Args:
-            image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
-                image locations at which the predictions are defined.
-                All ground truth inputs are sampled at these
-                locations in order to extract values that correspond
-                to the predictions.
-            images: A tensor of shape `(B, H, W, 3)` containing ground truth
-                rgb values.
-            images_pred: A tensor of shape `(B, ..., 3)` containing predicted
-                rgb values.
-            depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
-                depth values.
-            depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
-                depth values.
-            masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
-                foreground masks.
-            masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
-                foreground masks.
-            grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
-                of a gradient of a signed distance function w.r.t.
+            results: A dict to store the results in.
+            raymarched.features: Predicted rgb or feature values.
+            raymarched.depths: A tensor of shape `(B, ..., 1)` containing
+                predicted depth values.
+            raymarched.masks: A tensor of shape `(B, ..., 1)` containing
+                predicted foreground masks.
+            raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing an
+                evaluation of a gradient of a signed distance function w.r.t.
                 input 3D coordinates used to compute the eikonal loss.
-            density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
-                `Hg x Wg x Dg` voxel grid of density values.
+            raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
+                containing a `Hg x Wg x Dg` voxel grid of density values.
+            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
+                the predictions are defined. All ground truth inputs are sampled at
+                these locations in order to extract values that correspond to the
+                predictions.
+            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
+                values.
+            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
+                values.
+            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
+                foreground masks.
             keys_prefix: A common prefix for all keys in the output dictionary
-                containing all metrics.
-            mask_renders_by_pred: If `True`, masks rendered images by the predicted
-                `masks_pred` prior to computing all rgb metrics.
+                containing all view metrics.
 
         Returns:
-            metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
+            A dictionary `{metric_name_i: metric_value_i}` keyed by the
                 names of the output metrics `metric_name_i` with their corresponding
                 values `metric_value_i` represented as 0-dimensional float tensors.
@@ -91,109 +204,142 @@ class ViewMetrics(torch.nn.Module):
             depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
                 predicted depth values.
         """
+        metrics = self._calculate_stage(
+            raymarched,
+            xys,
+            image_rgb,
+            depth_map,
+            fg_probability,
+            mask_crop,
+            keys_prefix,
+        )
+        if raymarched.prev_stage:
+            metrics.update(
+                self(
+                    raymarched.prev_stage,
+                    xys,
+                    image_rgb,
+                    depth_map,
+                    fg_probability,
+                    mask_crop,
+                    keys_prefix=(keys_prefix + "prev_stage_"),
+                )
+            )
+
+        return metrics
+
+    def _calculate_stage(
+        self,
+        raymarched: RendererOutput,
+        xys: torch.Tensor,
+        image_rgb: Optional[torch.Tensor] = None,
+        depth_map: Optional[torch.Tensor] = None,
+        fg_probability: Optional[torch.Tensor] = None,
+        mask_crop: Optional[torch.Tensor] = None,
+        keys_prefix: str = "loss_",
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Calculate metrics for the current stage.
+        """
         # TODO: extract functions
         # reshape from B x ... x DIM to B x DIM x -1 x 1
-        images_pred, masks_pred, depths_pred = [
-            _reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
+        image_rgb_pred, fg_probability_pred, depth_map_pred = [
+            _reshape_nongrid_var(x)
+            for x in [raymarched.features, raymarched.masks, raymarched.depths]
         ]
         # reshape the sampling grid as well
         # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
         # now that we use rend_utils.ndc_grid_sample
-        image_sampling_grid = image_sampling_grid.reshape(
-            image_sampling_grid.shape[0], -1, 1, 2
-        )
+        xys = xys.reshape(xys.shape[0], -1, 1, 2)
 
-        # closure with the given image_sampling_grid
+        # closure with the given xys
         def sample(tensor, mode):
             if tensor is None:
                 return tensor
-            return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
+            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
 
         # eval all results in this size
-        images = sample(images, mode="bilinear")
-        depths = sample(depths, mode="nearest")
-        masks = sample(masks, mode="nearest")
-        masks_crop = sample(masks_crop, mode="nearest")
-        if masks_crop is None and images_pred is not None:
-            masks_crop = torch.ones_like(images_pred[:, :1])
-        if masks_crop is None and depths_pred is not None:
-            masks_crop = torch.ones_like(depths_pred[:, :1])
+        image_rgb = sample(image_rgb, mode="bilinear")
+        depth_map = sample(depth_map, mode="nearest")
+        fg_probability = sample(fg_probability, mode="nearest")
+        mask_crop = sample(mask_crop, mode="nearest")
+        if mask_crop is None and image_rgb_pred is not None:
+            mask_crop = torch.ones_like(image_rgb_pred[:, :1])
+        if mask_crop is None and depth_map_pred is not None:
+            mask_crop = torch.ones_like(depth_map_pred[:, :1])
 
-        preds = {}
-        if images is not None and images_pred is not None:
-            # TODO: mask_renders_by_pred is always false; simplify
-            preds.update(
+        metrics = {}
+        if image_rgb is not None and image_rgb_pred is not None:
+            metrics.update(
                 _rgb_metrics(
-                    images,
-                    images_pred,
-                    masks,
-                    masks_pred,
-                    masks_crop,
-                    mask_renders_by_pred,
+                    image_rgb,
+                    image_rgb_pred,
+                    fg_probability,
+                    fg_probability_pred,
+                    mask_crop,
                 )
             )
 
-        if masks_pred is not None:
-            preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
-        if masks is not None and masks_pred is not None:
-            preds["mask_neg_iou"] = utils.neg_iou_loss(
-                masks_pred, masks, mask=masks_crop
+        if fg_probability_pred is not None:
+            metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
+        if fg_probability is not None and fg_probability_pred is not None:
+            metrics["mask_neg_iou"] = utils.neg_iou_loss(
+                fg_probability_pred, fg_probability, mask=mask_crop
+            )
+            metrics["mask_bce"] = utils.calc_bce(
+                fg_probability_pred, fg_probability, mask=mask_crop
             )
-            preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
-        if depths is not None and depths_pred is not None:
-            assert masks_crop is not None
+        if depth_map is not None and depth_map_pred is not None:
+            assert mask_crop is not None
             _, abs_ = utils.eval_depth(
-                depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
+                depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
            )
-            preds["depth_abs"] = abs_.mean()
+            metrics["depth_abs"] = abs_.mean()
 
-            if masks is not None:
-                mask = masks * masks_crop
+            if fg_probability is not None:
+                mask = fg_probability * mask_crop
                 _, abs_ = utils.eval_depth(
-                    depths_pred, depths, get_best_scale=True, mask=mask, crop=0
+                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
                 )
-                preds["depth_abs_fg"] = abs_.mean()
+                metrics["depth_abs_fg"] = abs_.mean()
metrics["depth_abs_fg"] = abs_.mean() # regularizers + grad_theta = raymarched.aux.get("grad_theta") if grad_theta is not None: - preds["eikonal"] = _get_eikonal_loss(grad_theta) + metrics["eikonal"] = _get_eikonal_loss(grad_theta) + density_grid = raymarched.aux.get("density_grid") if density_grid is not None: - preds["density_tv"] = _get_grid_tv_loss(density_grid) + metrics["density_tv"] = _get_grid_tv_loss(density_grid) - if depths_pred is not None: - preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred) + if depth_map_pred is not None: + metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred) if keys_prefix is not None: - preds = {(keys_prefix + k): v for k, v in preds.items()} + metrics = {(keys_prefix + k): v for k, v in metrics.items()} - return preds + return metrics -def _rgb_metrics( - images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred -): +def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop): assert masks_crop is not None - if mask_renders_by_pred: - images = images[..., masks_pred.reshape(-1), :] - masks_crop = masks_crop[..., masks_pred.reshape(-1), :] - masks = masks is not None and masks[..., masks_pred.reshape(-1), :] rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True) rgb_loss = utils.huber(rgb_squared, scaling=0.03) crop_mass = masks_crop.sum().clamp(1.0) - preds = { + results = { "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass, "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass, "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop), } if masks is not None: masks = masks_crop * masks - preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks) - preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0) - return preds + results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks) + results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0) + return results def _get_eikonal_loss(grad_theta): diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index b1e0489f..f1826ae2 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -20,6 +20,8 @@ renderer_class_type: LSTMRenderer image_feature_extractor_class_type: ResNetFeatureExtractor view_pooler_enabled: true implicit_function_class_type: IdrFeatureField +view_metrics_class_type: ViewMetrics +regularization_metrics_class_type: RegularizationMetrics loss_weights: loss_rgb_mse: 1.0 loss_prev_stage_rgb_mse: 1.0 @@ -122,3 +124,5 @@ implicit_function_IdrFeatureField_args: n_harmonic_functions_xyz: 1729 pooled_feature_dim: 0 encoding_dim: 0 +view_metrics_ViewMetrics_args: {} +regularization_metrics_RegularizationMetrics_args: {}