mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Refactor ViewMetrics
Summary: Make ViewMetrics easy to replace by putting them into an OmegaConf dataclass. Also, re-word a few variable names and fix minor TODOs. Reviewed By: bottler Differential Revision: D37327157 fbshipit-source-id: 78d8e39bbb3548b952f10abbe05688409fb987cc
This commit is contained in:
		
							parent
							
								
									f4dd151037
								
							
						
					
					
						commit
						ae35824f21
					
				@ -21,6 +21,8 @@ generic_model_args:
 | 
			
		||||
  image_feature_extractor_class_type: null
 | 
			
		||||
  view_pooler_enabled: false
 | 
			
		||||
  implicit_function_class_type: NeuralRadianceFieldImplicitFunction
 | 
			
		||||
  view_metrics_class_type: ViewMetrics
 | 
			
		||||
  regularization_metrics_class_type: RegularizationMetrics
 | 
			
		||||
  loss_weights:
 | 
			
		||||
    loss_rgb_mse: 1.0
 | 
			
		||||
    loss_prev_stage_rgb_mse: 1.0
 | 
			
		||||
@ -268,6 +270,8 @@ generic_model_args:
 | 
			
		||||
      in_features: 256
 | 
			
		||||
      out_features: 3
 | 
			
		||||
      ray_dir_in_camera_coords: false
 | 
			
		||||
  view_metrics_ViewMetrics_args: {}
 | 
			
		||||
  regularization_metrics_RegularizationMetrics_args: {}
 | 
			
		||||
solver_args:
 | 
			
		||||
  breed: adam
 | 
			
		||||
  weight_decay: 0.0
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,8 @@ class ImplicitronRender:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImplicitronModelBase(ReplaceableBase):
 | 
			
		||||
    """Replaceable abstract base for all image generation / rendering models.
 | 
			
		||||
    """
 | 
			
		||||
    Replaceable abstract base for all image generation / rendering models.
 | 
			
		||||
    `forward()` method produces a render with a depth map.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import tqdm
 | 
			
		||||
from pytorch3d.implicitron.models.metrics import (  # noqa
 | 
			
		||||
    RegularizationMetrics,
 | 
			
		||||
    RegularizationMetricsBase,
 | 
			
		||||
    ViewMetrics,
 | 
			
		||||
    ViewMetricsBase,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
    expand_args_fields,
 | 
			
		||||
@ -42,7 +48,7 @@ from .implicit_function.scene_representation_networks import (  # noqa
 | 
			
		||||
    SRNHyperNetImplicitFunction,
 | 
			
		||||
    SRNImplicitFunction,
 | 
			
		||||
)
 | 
			
		||||
from .metrics import ViewMetrics
 | 
			
		||||
 | 
			
		||||
from .renderer.base import (
 | 
			
		||||
    BaseRenderer,
 | 
			
		||||
    EvaluationMode,
 | 
			
		||||
@ -184,6 +190,14 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 | 
			
		||||
            is available in the global registry.
 | 
			
		||||
        implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
 | 
			
		||||
            are initialised to be in self._implicit_functions.
 | 
			
		||||
        view_metrics: An instance of ViewMetricsBase used to compute loss terms which
 | 
			
		||||
            are independent of the model's parameters.
 | 
			
		||||
        view_metrics_class_type: The type of view metrics to use, must be available in
 | 
			
		||||
            the global registry.
 | 
			
		||||
        regularization_metrics: An instance of RegularizationMetricsBase used to compute
 | 
			
		||||
            regularization terms which can depend on the model's parameters.
 | 
			
		||||
        regularization_metrics_class_type: The type of regularization metrics to use,
 | 
			
		||||
            must be available in the global registry.
 | 
			
		||||
        loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
 | 
			
		||||
            for `ViewMetrics` class for available loss functions.
 | 
			
		||||
        log_vars: A list of variable names which should be logged.
 | 
			
		||||
@ -232,6 +246,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 | 
			
		||||
    # The actual implicit functions live in self._implicit_functions
 | 
			
		||||
    implicit_function: ImplicitFunctionBase
 | 
			
		||||
 | 
			
		||||
    # ----- metrics
 | 
			
		||||
    view_metrics: ViewMetricsBase
 | 
			
		||||
    view_metrics_class_type: str = "ViewMetrics"
 | 
			
		||||
 | 
			
		||||
    regularization_metrics: RegularizationMetricsBase
 | 
			
		||||
    regularization_metrics_class_type: str = "RegularizationMetrics"
 | 
			
		||||
 | 
			
		||||
    # ---- loss weights
 | 
			
		||||
    loss_weights: Dict[str, float] = field(
 | 
			
		||||
        default_factory=lambda: {
 | 
			
		||||
@ -269,7 +290,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.view_metrics = ViewMetrics()
 | 
			
		||||
 | 
			
		||||
        if self.view_pooler_enabled:
 | 
			
		||||
            if self.image_feature_extractor_class_type is None:
 | 
			
		||||
@ -424,15 +444,31 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 | 
			
		||||
        for func in self._implicit_functions:
 | 
			
		||||
            func.unbind_args()
 | 
			
		||||
 | 
			
		||||
        preds = self._get_view_metrics(
 | 
			
		||||
            raymarched=rendered,
 | 
			
		||||
            xys=ray_bundle.xys,
 | 
			
		||||
            image_rgb=None if image_rgb is None else image_rgb[:n_targets],
 | 
			
		||||
            depth_map=None if depth_map is None else depth_map[:n_targets],
 | 
			
		||||
            fg_probability=None
 | 
			
		||||
            if fg_probability is None
 | 
			
		||||
            else fg_probability[:n_targets],
 | 
			
		||||
            mask_crop=None if mask_crop is None else mask_crop[:n_targets],
 | 
			
		||||
        # A dict to store losses as well as rendering results.
 | 
			
		||||
        preds: Dict[str, Any] = {}
 | 
			
		||||
 | 
			
		||||
        def safe_slice_targets(
 | 
			
		||||
            tensor: Optional[torch.Tensor],
 | 
			
		||||
        ) -> Optional[torch.Tensor]:
 | 
			
		||||
            return None if tensor is None else tensor[:n_targets]
 | 
			
		||||
 | 
			
		||||
        preds.update(
 | 
			
		||||
            self.view_metrics(
 | 
			
		||||
                results=preds,
 | 
			
		||||
                raymarched=rendered,
 | 
			
		||||
                xys=ray_bundle.xys,
 | 
			
		||||
                image_rgb=safe_slice_targets(image_rgb),
 | 
			
		||||
                depth_map=safe_slice_targets(depth_map),
 | 
			
		||||
                fg_probability=safe_slice_targets(fg_probability),
 | 
			
		||||
                mask_crop=safe_slice_targets(mask_crop),
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        preds.update(
 | 
			
		||||
            self.regularization_metrics(
 | 
			
		||||
                results=preds,
 | 
			
		||||
                model=self,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
 | 
			
		||||
@ -462,11 +498,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 | 
			
		||||
        else:
 | 
			
		||||
            raise AssertionError("Unreachable state")
 | 
			
		||||
 | 
			
		||||
        # calc the AD penalty, returns None if autodecoder is not active
 | 
			
		||||
        ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
 | 
			
		||||
        if ad_penalty is not None:
 | 
			
		||||
            preds["loss_autodecoder_norm"] = ad_penalty
 | 
			
		||||
 | 
			
		||||
        # (7) Compute losses
 | 
			
		||||
        # finally get the optimization objective using self.loss_weights
 | 
			
		||||
        objective = self._get_objective(preds)
 | 
			
		||||
@ -744,45 +775,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
 | 
			
		||||
 | 
			
		||||
        return image_rgb, fg_mask, depth_map
 | 
			
		||||
 | 
			
		||||
    def _get_view_metrics(
 | 
			
		||||
        self,
 | 
			
		||||
        raymarched: RendererOutput,
 | 
			
		||||
        xys: torch.Tensor,
 | 
			
		||||
        image_rgb: Optional[torch.Tensor] = None,
 | 
			
		||||
        depth_map: Optional[torch.Tensor] = None,
 | 
			
		||||
        fg_probability: Optional[torch.Tensor] = None,
 | 
			
		||||
        mask_crop: Optional[torch.Tensor] = None,
 | 
			
		||||
        keys_prefix: str = "loss_",
 | 
			
		||||
    ):
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        metrics = self.view_metrics(
 | 
			
		||||
            image_sampling_grid=xys,
 | 
			
		||||
            images_pred=raymarched.features,
 | 
			
		||||
            images=image_rgb,
 | 
			
		||||
            depths_pred=raymarched.depths,
 | 
			
		||||
            depths=depth_map,
 | 
			
		||||
            masks_pred=raymarched.masks,
 | 
			
		||||
            masks=fg_probability,
 | 
			
		||||
            masks_crop=mask_crop,
 | 
			
		||||
            keys_prefix=keys_prefix,
 | 
			
		||||
            **raymarched.aux,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if raymarched.prev_stage:
 | 
			
		||||
            metrics.update(
 | 
			
		||||
                self._get_view_metrics(
 | 
			
		||||
                    raymarched.prev_stage,
 | 
			
		||||
                    xys,
 | 
			
		||||
                    image_rgb,
 | 
			
		||||
                    depth_map,
 | 
			
		||||
                    fg_probability,
 | 
			
		||||
                    mask_crop,
 | 
			
		||||
                    keys_prefix=(keys_prefix + "prev_stage_"),
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return metrics
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def _rasterize_mc_samples(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
@ -6,64 +6,180 @@
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Dict, Optional
 | 
			
		||||
from typing import Any, Dict, Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.tools import metric_utils as utils
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
			
		||||
from pytorch3d.renderer import utils as rend_utils
 | 
			
		||||
 | 
			
		||||
from .renderer.base import RendererOutput
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    Replaceable abstract base for regularization metrics.
 | 
			
		||||
    `forward()` method produces regularization metrics and (unlike ViewMetrics) can
 | 
			
		||||
    depend on the model's parameters.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, model: Any, keys_prefix: str = "loss_", **kwargs
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates various regularization terms useful for supervising differentiable
 | 
			
		||||
        rendering pipelines.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model: A model instance. Useful, for example, to implement
 | 
			
		||||
                weights-based regularization.
 | 
			
		||||
            keys_prefix: A common prefix for all keys in the output dictionary
 | 
			
		||||
                containing all regularization metrics.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A dictionary with the resulting regularization metrics. The items
 | 
			
		||||
                will have form `{metric_name_i: metric_value_i}` keyed by the
 | 
			
		||||
                names of the output metrics `metric_name_i` with their corresponding
 | 
			
		||||
                values `metric_value_i` represented as 0-dimensional float tensors.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    Replaceable abstract base for model metrics.
 | 
			
		||||
    `forward()` method produces losses and other metrics.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
class ViewMetrics(torch.nn.Module):
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        image_sampling_grid: torch.Tensor,
 | 
			
		||||
        images: Optional[torch.Tensor] = None,
 | 
			
		||||
        images_pred: Optional[torch.Tensor] = None,
 | 
			
		||||
        depths: Optional[torch.Tensor] = None,
 | 
			
		||||
        depths_pred: Optional[torch.Tensor] = None,
 | 
			
		||||
        masks: Optional[torch.Tensor] = None,
 | 
			
		||||
        masks_pred: Optional[torch.Tensor] = None,
 | 
			
		||||
        masks_crop: Optional[torch.Tensor] = None,
 | 
			
		||||
        grad_theta: Optional[torch.Tensor] = None,
 | 
			
		||||
        density_grid: Optional[torch.Tensor] = None,
 | 
			
		||||
        raymarched: RendererOutput,
 | 
			
		||||
        xys: torch.Tensor,
 | 
			
		||||
        image_rgb: Optional[torch.Tensor] = None,
 | 
			
		||||
        depth_map: Optional[torch.Tensor] = None,
 | 
			
		||||
        fg_probability: Optional[torch.Tensor] = None,
 | 
			
		||||
        mask_crop: Optional[torch.Tensor] = None,
 | 
			
		||||
        keys_prefix: str = "loss_",
 | 
			
		||||
        mask_renders_by_pred: bool = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Dict[str, torch.Tensor]:
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates various metrics and loss functions useful for supervising
 | 
			
		||||
        differentiable rendering pipelines. Any additional parameters can be passed
 | 
			
		||||
        in the `raymarched.aux` dictionary.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            results: A dictionary with the resulting view metrics. The items
 | 
			
		||||
                will have form `{metric_name_i: metric_value_i}` keyed by the
 | 
			
		||||
                names of the output metrics `metric_name_i` with their corresponding
 | 
			
		||||
                values `metric_value_i` represented as 0-dimensional float tensors.
 | 
			
		||||
            raymarched: Output of the renderer.
 | 
			
		||||
            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
 | 
			
		||||
                the predictions are defined. All ground truth inputs are sampled at
 | 
			
		||||
                these locations in order to extract values that correspond to the
 | 
			
		||||
                predictions.
 | 
			
		||||
            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
 | 
			
		||||
                values.
 | 
			
		||||
            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
 | 
			
		||||
                values.
 | 
			
		||||
            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
 | 
			
		||||
                foreground masks.
 | 
			
		||||
            keys_prefix: A common prefix for all keys in the output dictionary
 | 
			
		||||
                containing all view metrics.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A dictionary with the resulting view metrics. The items
 | 
			
		||||
                will have form `{metric_name_i: metric_value_i}` keyed by the
 | 
			
		||||
                names of the output metrics `metric_name_i` with their corresponding
 | 
			
		||||
                values `metric_value_i` represented as 0-dimensional float tensors.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class RegularizationMetrics(RegularizationMetricsBase):
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, model: Any, keys_prefix: str = "loss_", **kwargs
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates the AD penalty, or returns an empty dict if the model's autoencoder
 | 
			
		||||
        is inactive.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model: A model instance.
 | 
			
		||||
            keys_prefix: A common prefix for all keys in the output dictionary
 | 
			
		||||
                containing all regularization metrics.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A dictionary with the resulting regularization metrics. The items
 | 
			
		||||
                will have form `{metric_name_i: metric_value_i}` keyed by the
 | 
			
		||||
                names of the output metrics `metric_name_i` with their corresponding
 | 
			
		||||
                values `metric_value_i` represented as 0-dimensional float tensors.
 | 
			
		||||
 | 
			
		||||
            The calculated metric is:
 | 
			
		||||
                autoencoder_norm: Autoencoder weight norm regularization term.
 | 
			
		||||
        """
 | 
			
		||||
        metrics = {}
 | 
			
		||||
        if getattr(model, "sequence_autodecoder", None) is not None:
 | 
			
		||||
            ad_penalty = model.sequence_autodecoder.calc_squared_encoding_norm()
 | 
			
		||||
            if ad_penalty is not None:
 | 
			
		||||
                metrics["autodecoder_norm"] = ad_penalty
 | 
			
		||||
 | 
			
		||||
        if keys_prefix is not None:
 | 
			
		||||
            metrics = {(keys_prefix + k): v for k, v in metrics.items()}
 | 
			
		||||
 | 
			
		||||
        return metrics
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class ViewMetrics(ViewMetricsBase):
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        raymarched: RendererOutput,
 | 
			
		||||
        xys: torch.Tensor,
 | 
			
		||||
        image_rgb: Optional[torch.Tensor] = None,
 | 
			
		||||
        depth_map: Optional[torch.Tensor] = None,
 | 
			
		||||
        fg_probability: Optional[torch.Tensor] = None,
 | 
			
		||||
        mask_crop: Optional[torch.Tensor] = None,
 | 
			
		||||
        keys_prefix: str = "loss_",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates various differentiable metrics useful for supervising
 | 
			
		||||
        differentiable rendering pipelines.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
 | 
			
		||||
                    image locations at which the predictions are defined.
 | 
			
		||||
                    All ground truth inputs are sampled at these
 | 
			
		||||
                    locations in order to extract values that correspond
 | 
			
		||||
                    to the predictions.
 | 
			
		||||
            images: A tensor of shape `(B, H, W, 3)` containing ground truth
 | 
			
		||||
                rgb values.
 | 
			
		||||
            images_pred: A tensor of shape `(B, ..., 3)` containing predicted
 | 
			
		||||
                rgb values.
 | 
			
		||||
            depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
 | 
			
		||||
                depth values.
 | 
			
		||||
            depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
 | 
			
		||||
                depth values.
 | 
			
		||||
            masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
 | 
			
		||||
                foreground masks.
 | 
			
		||||
            masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
 | 
			
		||||
                foreground masks.
 | 
			
		||||
            grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
 | 
			
		||||
                of a gradient of a signed distance function w.r.t.
 | 
			
		||||
            results: A dict to store the results in.
 | 
			
		||||
            raymarched.features: Predicted rgb or feature values.
 | 
			
		||||
            raymarched.depths: A tensor of shape `(B, ..., 1)` containing
 | 
			
		||||
                predicted depth values.
 | 
			
		||||
            raymarched.masks: A tensor of shape `(B, ..., 1)` containing
 | 
			
		||||
                predicted foreground masks.
 | 
			
		||||
            raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing an
 | 
			
		||||
                evaluation of a gradient of a signed distance function w.r.t.
 | 
			
		||||
                input 3D coordinates used to compute the eikonal loss.
 | 
			
		||||
            density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
 | 
			
		||||
                `Hg x Wg x Dg` voxel grid of density values.
 | 
			
		||||
            raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
 | 
			
		||||
                containing a `Hg x Wg x Dg` voxel grid of density values.
 | 
			
		||||
            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
 | 
			
		||||
                the predictions are defined. All ground truth inputs are sampled at
 | 
			
		||||
                these locations in order to extract values that correspond to the
 | 
			
		||||
                predictions.
 | 
			
		||||
            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
 | 
			
		||||
                values.
 | 
			
		||||
            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
 | 
			
		||||
                values.
 | 
			
		||||
            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
 | 
			
		||||
                foreground masks.
 | 
			
		||||
            keys_prefix: A common prefix for all keys in the output dictionary
 | 
			
		||||
                containing all metrics.
 | 
			
		||||
            mask_renders_by_pred: If `True`, masks rendered images by the predicted
 | 
			
		||||
                `masks_pred` prior to computing all rgb metrics.
 | 
			
		||||
                containing all view metrics.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
 | 
			
		||||
            A dictionary `{metric_name_i: metric_value_i}` keyed by the
 | 
			
		||||
                names of the output metrics `metric_name_i` with their corresponding
 | 
			
		||||
                values `metric_value_i` represented as 0-dimensional float tensors.
 | 
			
		||||
 | 
			
		||||
@ -91,109 +207,142 @@ class ViewMetrics(torch.nn.Module):
 | 
			
		||||
                    depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
 | 
			
		||||
                        predicted depth values.
 | 
			
		||||
        """
 | 
			
		||||
        metrics = self._calculate_stage(
 | 
			
		||||
            raymarched,
 | 
			
		||||
            xys,
 | 
			
		||||
            image_rgb,
 | 
			
		||||
            depth_map,
 | 
			
		||||
            fg_probability,
 | 
			
		||||
            mask_crop,
 | 
			
		||||
            keys_prefix,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if raymarched.prev_stage:
 | 
			
		||||
            metrics.update(
 | 
			
		||||
                self(
 | 
			
		||||
                    raymarched.prev_stage,
 | 
			
		||||
                    xys,
 | 
			
		||||
                    image_rgb,
 | 
			
		||||
                    depth_map,
 | 
			
		||||
                    fg_probability,
 | 
			
		||||
                    mask_crop,
 | 
			
		||||
                    keys_prefix=(keys_prefix + "prev_stage_"),
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return metrics
 | 
			
		||||
 | 
			
		||||
    def _calculate_stage(
 | 
			
		||||
        self,
 | 
			
		||||
        raymarched: RendererOutput,
 | 
			
		||||
        xys: torch.Tensor,
 | 
			
		||||
        image_rgb: Optional[torch.Tensor] = None,
 | 
			
		||||
        depth_map: Optional[torch.Tensor] = None,
 | 
			
		||||
        fg_probability: Optional[torch.Tensor] = None,
 | 
			
		||||
        mask_crop: Optional[torch.Tensor] = None,
 | 
			
		||||
        keys_prefix: str = "loss_",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Calculate metrics for the current stage.
 | 
			
		||||
        """
 | 
			
		||||
        # TODO: extract functions
 | 
			
		||||
 | 
			
		||||
        # reshape from B x ... x DIM to B x DIM x -1 x 1
 | 
			
		||||
        images_pred, masks_pred, depths_pred = [
 | 
			
		||||
            _reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
 | 
			
		||||
        image_rgb_pred, fg_probability_pred, depth_map_pred = [
 | 
			
		||||
            _reshape_nongrid_var(x)
 | 
			
		||||
            for x in [raymarched.features, raymarched.masks, raymarched.depths]
 | 
			
		||||
        ]
 | 
			
		||||
        # reshape the sampling grid as well
 | 
			
		||||
        # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
 | 
			
		||||
        # now that we use rend_utils.ndc_grid_sample
 | 
			
		||||
        image_sampling_grid = image_sampling_grid.reshape(
 | 
			
		||||
            image_sampling_grid.shape[0], -1, 1, 2
 | 
			
		||||
        )
 | 
			
		||||
        xys = xys.reshape(xys.shape[0], -1, 1, 2)
 | 
			
		||||
 | 
			
		||||
        # closure with the given image_sampling_grid
 | 
			
		||||
        # closure with the given xys
 | 
			
		||||
        def sample(tensor, mode):
 | 
			
		||||
            if tensor is None:
 | 
			
		||||
                return tensor
 | 
			
		||||
            return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
 | 
			
		||||
            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
 | 
			
		||||
 | 
			
		||||
        # eval all results in this size
 | 
			
		||||
        images = sample(images, mode="bilinear")
 | 
			
		||||
        depths = sample(depths, mode="nearest")
 | 
			
		||||
        masks = sample(masks, mode="nearest")
 | 
			
		||||
        masks_crop = sample(masks_crop, mode="nearest")
 | 
			
		||||
        if masks_crop is None and images_pred is not None:
 | 
			
		||||
            masks_crop = torch.ones_like(images_pred[:, :1])
 | 
			
		||||
        if masks_crop is None and depths_pred is not None:
 | 
			
		||||
            masks_crop = torch.ones_like(depths_pred[:, :1])
 | 
			
		||||
        image_rgb = sample(image_rgb, mode="bilinear")
 | 
			
		||||
        depth_map = sample(depth_map, mode="nearest")
 | 
			
		||||
        fg_probability = sample(fg_probability, mode="nearest")
 | 
			
		||||
        mask_crop = sample(mask_crop, mode="nearest")
 | 
			
		||||
        if mask_crop is None and image_rgb_pred is not None:
 | 
			
		||||
            mask_crop = torch.ones_like(image_rgb_pred[:, :1])
 | 
			
		||||
        if mask_crop is None and depth_map_pred is not None:
 | 
			
		||||
            mask_crop = torch.ones_like(depth_map_pred[:, :1])
 | 
			
		||||
 | 
			
		||||
        preds = {}
 | 
			
		||||
        if images is not None and images_pred is not None:
 | 
			
		||||
            # TODO: mask_renders_by_pred is always false; simplify
 | 
			
		||||
            preds.update(
 | 
			
		||||
        metrics = {}
 | 
			
		||||
        if image_rgb is not None and image_rgb_pred is not None:
 | 
			
		||||
            metrics.update(
 | 
			
		||||
                _rgb_metrics(
 | 
			
		||||
                    images,
 | 
			
		||||
                    images_pred,
 | 
			
		||||
                    masks,
 | 
			
		||||
                    masks_pred,
 | 
			
		||||
                    masks_crop,
 | 
			
		||||
                    mask_renders_by_pred,
 | 
			
		||||
                    image_rgb,
 | 
			
		||||
                    image_rgb_pred,
 | 
			
		||||
                    fg_probability,
 | 
			
		||||
                    fg_probability_pred,
 | 
			
		||||
                    mask_crop,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if masks_pred is not None:
 | 
			
		||||
            preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
 | 
			
		||||
        if masks is not None and masks_pred is not None:
 | 
			
		||||
            preds["mask_neg_iou"] = utils.neg_iou_loss(
 | 
			
		||||
                masks_pred, masks, mask=masks_crop
 | 
			
		||||
        if fg_probability_pred is not None:
 | 
			
		||||
            metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
 | 
			
		||||
        if fg_probability is not None and fg_probability_pred is not None:
 | 
			
		||||
            metrics["mask_neg_iou"] = utils.neg_iou_loss(
 | 
			
		||||
                fg_probability_pred, fg_probability, mask=mask_crop
 | 
			
		||||
            )
 | 
			
		||||
            metrics["mask_bce"] = utils.calc_bce(
 | 
			
		||||
                fg_probability_pred, fg_probability, mask=mask_crop
 | 
			
		||||
            )
 | 
			
		||||
            preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
 | 
			
		||||
 | 
			
		||||
        if depths is not None and depths_pred is not None:
 | 
			
		||||
            assert masks_crop is not None
 | 
			
		||||
        if depth_map is not None and depth_map_pred is not None:
 | 
			
		||||
            assert mask_crop is not None
 | 
			
		||||
            _, abs_ = utils.eval_depth(
 | 
			
		||||
                depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
 | 
			
		||||
                depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
 | 
			
		||||
            )
 | 
			
		||||
            preds["depth_abs"] = abs_.mean()
 | 
			
		||||
            metrics["depth_abs"] = abs_.mean()
 | 
			
		||||
 | 
			
		||||
            if masks is not None:
 | 
			
		||||
                mask = masks * masks_crop
 | 
			
		||||
            if fg_probability is not None:
 | 
			
		||||
                mask = fg_probability * mask_crop
 | 
			
		||||
                _, abs_ = utils.eval_depth(
 | 
			
		||||
                    depths_pred, depths, get_best_scale=True, mask=mask, crop=0
 | 
			
		||||
                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
 | 
			
		||||
                )
 | 
			
		||||
                preds["depth_abs_fg"] = abs_.mean()
 | 
			
		||||
                metrics["depth_abs_fg"] = abs_.mean()
 | 
			
		||||
 | 
			
		||||
        # regularizers
 | 
			
		||||
        grad_theta = raymarched.aux.get("grad_theta")
 | 
			
		||||
        if grad_theta is not None:
 | 
			
		||||
            preds["eikonal"] = _get_eikonal_loss(grad_theta)
 | 
			
		||||
            metrics["eikonal"] = _get_eikonal_loss(grad_theta)
 | 
			
		||||
 | 
			
		||||
        density_grid = raymarched.aux.get("density_grid")
 | 
			
		||||
        if density_grid is not None:
 | 
			
		||||
            preds["density_tv"] = _get_grid_tv_loss(density_grid)
 | 
			
		||||
            metrics["density_tv"] = _get_grid_tv_loss(density_grid)
 | 
			
		||||
 | 
			
		||||
        if depths_pred is not None:
 | 
			
		||||
            preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
 | 
			
		||||
        if depth_map_pred is not None:
 | 
			
		||||
            metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)
 | 
			
		||||
 | 
			
		||||
        if keys_prefix is not None:
 | 
			
		||||
            preds = {(keys_prefix + k): v for k, v in preds.items()}
 | 
			
		||||
            metrics = {(keys_prefix + k): v for k, v in metrics.items()}
 | 
			
		||||
 | 
			
		||||
        return preds
 | 
			
		||||
        return metrics
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _rgb_metrics(
 | 
			
		||||
    images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
 | 
			
		||||
):
 | 
			
		||||
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
 | 
			
		||||
    assert masks_crop is not None
 | 
			
		||||
    if mask_renders_by_pred:
 | 
			
		||||
        images = images[..., masks_pred.reshape(-1), :]
 | 
			
		||||
        masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
 | 
			
		||||
        masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
 | 
			
		||||
    rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
 | 
			
		||||
    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
 | 
			
		||||
    crop_mass = masks_crop.sum().clamp(1.0)
 | 
			
		||||
    preds = {
 | 
			
		||||
    results = {
 | 
			
		||||
        "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
 | 
			
		||||
        "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
 | 
			
		||||
        "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
 | 
			
		||||
    }
 | 
			
		||||
    if masks is not None:
 | 
			
		||||
        masks = masks_crop * masks
 | 
			
		||||
        preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
 | 
			
		||||
        preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
 | 
			
		||||
    return preds
 | 
			
		||||
        results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
 | 
			
		||||
        results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
 | 
			
		||||
    return results
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_eikonal_loss(grad_theta):
 | 
			
		||||
 | 
			
		||||
@ -20,6 +20,8 @@ renderer_class_type: LSTMRenderer
 | 
			
		||||
image_feature_extractor_class_type: ResNetFeatureExtractor
 | 
			
		||||
view_pooler_enabled: true
 | 
			
		||||
implicit_function_class_type: IdrFeatureField
 | 
			
		||||
view_metrics_class_type: ViewMetrics
 | 
			
		||||
regularization_metrics_class_type: RegularizationMetrics
 | 
			
		||||
loss_weights:
 | 
			
		||||
  loss_rgb_mse: 1.0
 | 
			
		||||
  loss_prev_stage_rgb_mse: 1.0
 | 
			
		||||
@ -122,3 +124,5 @@ implicit_function_IdrFeatureField_args:
 | 
			
		||||
  n_harmonic_functions_xyz: 1729
 | 
			
		||||
  pooled_feature_dim: 0
 | 
			
		||||
  encoding_dim: 0
 | 
			
		||||
view_metrics_ViewMetrics_args: {}
 | 
			
		||||
regularization_metrics_RegularizationMetrics_args: {}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user