mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Add utils to approximate the conical frustums as multivariate gaussians.
Summary: Introduce methods to approximate the radii of conical frustums along rays as described in [MipNerf](https://arxiv.org/abs/2103.13415): - Two new attributes are added to ImplicitronRayBundle: bins and radii. Bins is of size n_pts_per_ray + 1. It allows us to easily manipulate the n_pts_per_ray intervals. For example, we need the interval coordinates in the radii computation for \(t_{\mu}, t_{\delta}\). Radii are used to store the radii of the conical frustums. - Add 3 new methods to compute the radii: - approximate_conical_frustum_as_gaussians: It computes the mean along the ray direction, the variance of the conical frustum with respect to t, and the variance of the conical frustum with respect to its radius. This implementation follows the stable computation defined in the paper. - compute_3d_diagonal_covariance_gaussian: Leverages the two previously computed variances to find the diagonal covariance of the Gaussian. - conical_frustum_to_gaussian: Combines everything to compute the means and the diagonal covariances of the Gaussians along the ray. - In AbstractMaskRaySampler, introduce the attribute `cast_ray_bundle_as_cone`. If False, the previous behaviour of the RaySampler is unchanged. If True, the samplers sample `n_pts_per_ray + 1` points instead of `n_pts_per_ray`. These points are then used to set the bins attribute of ImplicitronRayBundle. Support for HeterogeneousRayBundle has not been added since the current code does not allow it. A safeguard has been added to avoid a silent bug in the future. Reviewed By: shapovalov Differential Revision: D45269190 fbshipit-source-id: bf22fad12d71d55392f054e3f680013aa0d59b78
This commit is contained in:
		
							parent
							
								
									4e7715ce66
								
							
						
					
					
						commit
						29b8ebd802
					
				@ -216,6 +216,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_rays_total_training: null
 | 
					      n_rays_total_training: null
 | 
				
			||||||
      stratified_point_sampling_training: true
 | 
					      stratified_point_sampling_training: true
 | 
				
			||||||
      stratified_point_sampling_evaluation: false
 | 
					      stratified_point_sampling_evaluation: false
 | 
				
			||||||
 | 
					      cast_ray_bundle_as_cone: false
 | 
				
			||||||
      scene_extent: 8.0
 | 
					      scene_extent: 8.0
 | 
				
			||||||
      scene_center:
 | 
					      scene_center:
 | 
				
			||||||
      - 0.0
 | 
					      - 0.0
 | 
				
			||||||
@ -228,6 +229,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_rays_total_training: null
 | 
					      n_rays_total_training: null
 | 
				
			||||||
      stratified_point_sampling_training: true
 | 
					      stratified_point_sampling_training: true
 | 
				
			||||||
      stratified_point_sampling_evaluation: false
 | 
					      stratified_point_sampling_evaluation: false
 | 
				
			||||||
 | 
					      cast_ray_bundle_as_cone: false
 | 
				
			||||||
      min_depth: 0.1
 | 
					      min_depth: 0.1
 | 
				
			||||||
      max_depth: 8.0
 | 
					      max_depth: 8.0
 | 
				
			||||||
    renderer_LSTMRenderer_args:
 | 
					    renderer_LSTMRenderer_args:
 | 
				
			||||||
@ -642,6 +644,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_rays_total_training: null
 | 
					      n_rays_total_training: null
 | 
				
			||||||
      stratified_point_sampling_training: true
 | 
					      stratified_point_sampling_training: true
 | 
				
			||||||
      stratified_point_sampling_evaluation: false
 | 
					      stratified_point_sampling_evaluation: false
 | 
				
			||||||
 | 
					      cast_ray_bundle_as_cone: false
 | 
				
			||||||
      scene_extent: 8.0
 | 
					      scene_extent: 8.0
 | 
				
			||||||
      scene_center:
 | 
					      scene_center:
 | 
				
			||||||
      - 0.0
 | 
					      - 0.0
 | 
				
			||||||
@ -654,6 +657,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_rays_total_training: null
 | 
					      n_rays_total_training: null
 | 
				
			||||||
      stratified_point_sampling_training: true
 | 
					      stratified_point_sampling_training: true
 | 
				
			||||||
      stratified_point_sampling_evaluation: false
 | 
					      stratified_point_sampling_evaluation: false
 | 
				
			||||||
 | 
					      cast_ray_bundle_as_cone: false
 | 
				
			||||||
      min_depth: 0.1
 | 
					      min_depth: 0.1
 | 
				
			||||||
      max_depth: 8.0
 | 
					      max_depth: 8.0
 | 
				
			||||||
    renderer_LSTMRenderer_args:
 | 
					    renderer_LSTMRenderer_args:
 | 
				
			||||||
 | 
				
			|||||||
@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
 | 
					from pytorch3d.implicitron.tools.config import ReplaceableBase
 | 
				
			||||||
from pytorch3d.ops import packed_to_padded
 | 
					from pytorch3d.ops import packed_to_padded
 | 
				
			||||||
 | 
					from pytorch3d.renderer.implicit.utils import ray_bundle_variables_to_ray_points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class EvaluationMode(Enum):
 | 
					class EvaluationMode(Enum):
 | 
				
			||||||
@ -47,6 +48,27 @@ class ImplicitronRayBundle:
 | 
				
			|||||||
        camera_counts: A tensor of shape (N, ) which how many times the
 | 
					        camera_counts: A tensor of shape (N, ) which how many times the
 | 
				
			||||||
            coresponding camera in `camera_ids` was sampled.
 | 
					            coresponding camera in `camera_ids` was sampled.
 | 
				
			||||||
            `sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.
 | 
					            `sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Attributes:
 | 
				
			||||||
 | 
					        origins: A tensor of shape `(..., 3)` denoting the
 | 
				
			||||||
 | 
					            origins of the sampling rays in world coords.
 | 
				
			||||||
 | 
					        directions: A tensor of shape `(..., 3)` containing the direction
 | 
				
			||||||
 | 
					            vectors of sampling rays in world coords. They don't have to be normalized;
 | 
				
			||||||
 | 
					            they define unit vectors in the respective 1D coordinate systems; see
 | 
				
			||||||
 | 
					            documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
 | 
				
			||||||
 | 
					        lengths: A tensor of shape `(..., num_points_per_ray)`
 | 
				
			||||||
 | 
					            containing the lengths at which the rays are sampled.
 | 
				
			||||||
 | 
					        xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
 | 
				
			||||||
 | 
					        camera_ids: An optional tensor of shape (N, ) which indicates which camera
 | 
				
			||||||
 | 
					            was used to sample the rays. `N` is the number of unique sampled cameras.
 | 
				
			||||||
 | 
					        camera_counts: An optional tensor of shape (N, ) indicates how many times the
 | 
				
			||||||
 | 
					            coresponding camera in `camera_ids` was sampled.
 | 
				
			||||||
 | 
					            `sum(camera_counts)==total_number_of_rays`.
 | 
				
			||||||
 | 
					        bins: An optional tensor of shape `(..., num_points_per_ray + 1)`
 | 
				
			||||||
 | 
					            containing the bins at which the rays are sampled. In this case
 | 
				
			||||||
 | 
					            lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
 | 
				
			||||||
 | 
					        pixel_radii_2d: An optional tensor of shape `(..., 1)`
 | 
				
			||||||
 | 
					            base radii of the conical frustums.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    origins: torch.Tensor
 | 
					    origins: torch.Tensor
 | 
				
			||||||
@ -55,6 +77,45 @@ class ImplicitronRayBundle:
 | 
				
			|||||||
    xys: torch.Tensor
 | 
					    xys: torch.Tensor
 | 
				
			||||||
    camera_ids: Optional[torch.LongTensor] = None
 | 
					    camera_ids: Optional[torch.LongTensor] = None
 | 
				
			||||||
    camera_counts: Optional[torch.LongTensor] = None
 | 
					    camera_counts: Optional[torch.LongTensor] = None
 | 
				
			||||||
 | 
					    bins: Optional[torch.Tensor] = None
 | 
				
			||||||
 | 
					    pixel_radii_2d: Optional[torch.Tensor] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def from_bins(
 | 
				
			||||||
 | 
					        cls,
 | 
				
			||||||
 | 
					        origins: torch.Tensor,
 | 
				
			||||||
 | 
					        directions: torch.Tensor,
 | 
				
			||||||
 | 
					        bins: torch.Tensor,
 | 
				
			||||||
 | 
					        xys: torch.Tensor,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
 | 
					    ) -> "ImplicitronRayBundle":
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Creates a new instance from bins instead of lengths.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Attributes:
 | 
				
			||||||
 | 
					            origins: A tensor of shape `(..., 3)` denoting the
 | 
				
			||||||
 | 
					                origins of the sampling rays in world coords.
 | 
				
			||||||
 | 
					            directions: A tensor of shape `(..., 3)` containing the direction
 | 
				
			||||||
 | 
					                vectors of sampling rays in world coords. They don't have to be normalized;
 | 
				
			||||||
 | 
					                they define unit vectors in the respective 1D coordinate systems; see
 | 
				
			||||||
 | 
					                documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
 | 
				
			||||||
 | 
					            bins: A tensor of shape `(..., num_points_per_ray + 1)`
 | 
				
			||||||
 | 
					                containing the bins at which the rays are sampled. In this case
 | 
				
			||||||
 | 
					                lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
 | 
				
			||||||
 | 
					            xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
 | 
				
			||||||
 | 
					            kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            An instance of ImplicitronRayBundle.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if bins.shape[-1] <= 1:
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                "The last dim of bins must be at least superior or equal to 2."
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
 | 
				
			||||||
 | 
					        lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return cls(origins, directions, lengths, xys, bins=bins, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_packed(self) -> bool:
 | 
					    def is_packed(self) -> bool:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@ -195,3 +256,154 @@ class BaseRenderer(ABC, ReplaceableBase):
 | 
				
			|||||||
            instance of RendererOutput
 | 
					            instance of RendererOutput
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def compute_3d_diagonal_covariance_gaussian(
 | 
				
			||||||
 | 
					    rays_directions: torch.Tensor,
 | 
				
			||||||
 | 
					    rays_dir_variance: torch.Tensor,
 | 
				
			||||||
 | 
					    radii_variance: torch.Tensor,
 | 
				
			||||||
 | 
					    eps: float = 1e-6,
 | 
				
			||||||
 | 
					) -> torch.Tensor:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Transform the variances (rays_dir_variance, radii_variance) of the gaussians from
 | 
				
			||||||
 | 
					    the coordinate frame of the conical frustum to 3D world coordinates.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    It follows the equation 16 of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        rays_directions: A tensor of shape `(..., 3)`
 | 
				
			||||||
 | 
					        rays_dir_variance: A tensor of shape `(..., num_intervals)` representing
 | 
				
			||||||
 | 
					            the variance of the conical frustum  with respect to the rays direction.
 | 
				
			||||||
 | 
					        radii_variance: A tensor of shape `(..., num_intervals)` representing
 | 
				
			||||||
 | 
					            the variance of the conical frustum with respect to its radius.
 | 
				
			||||||
 | 
					        eps: a small number to prevent division by zero.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        A tensor of shape `(..., num_intervals, 3)` containing the diagonal
 | 
				
			||||||
 | 
					            of the covariance matrix.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    d_outer_diag = torch.pow(rays_directions, 2)
 | 
				
			||||||
 | 
					    dir_mag_sq = torch.clamp(torch.sum(d_outer_diag, dim=-1, keepdim=True), min=eps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    null_outer_diag = 1 - d_outer_diag / dir_mag_sq
 | 
				
			||||||
 | 
					    ray_dir_cov_diag = rays_dir_variance[..., None] * d_outer_diag[..., None, :]
 | 
				
			||||||
 | 
					    xy_cov_diag = radii_variance[..., None] * null_outer_diag[..., None, :]
 | 
				
			||||||
 | 
					    return ray_dir_cov_diag + xy_cov_diag
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def approximate_conical_frustum_as_gaussians(
 | 
				
			||||||
 | 
					    bins: torch.Tensor, radii: torch.Tensor
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Approximates a conical frustum as two Gaussian distributions.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The Gaussian distributions are characterized by
 | 
				
			||||||
 | 
					    three values:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    - rays_dir_mean: mean along the rays direction
 | 
				
			||||||
 | 
					        (defined as t in the parametric representation of a cone).
 | 
				
			||||||
 | 
					    - rays_dir_variance: the variance of the conical frustum  along the rays direction.
 | 
				
			||||||
 | 
					    - radii_variance: variance of the conical frustum with respect to its radius.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The computation is stable and follows equation 7
 | 
				
			||||||
 | 
					    of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    For more information on how the mean and variances are computed
 | 
				
			||||||
 | 
					    refers to the appendix of the paper.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        bins: A tensor of shape `(..., num_points_per_ray + 1)`
 | 
				
			||||||
 | 
					            containing the bins at which the rays are sampled.
 | 
				
			||||||
 | 
					            `bin[..., t]` and `bin[..., t+1]` represent respectively
 | 
				
			||||||
 | 
					            the left and right coordinates of the interval.
 | 
				
			||||||
 | 
					        t0: A tensor of shape `(..., num_points_per_ray)`
 | 
				
			||||||
 | 
					            containing the left coordinates of the intervals
 | 
				
			||||||
 | 
					            on which the rays are sampled.
 | 
				
			||||||
 | 
					        t1: A tensor of shape `(..., num_points_per_ray)`
 | 
				
			||||||
 | 
					            containing the rights coordinates of the intervals
 | 
				
			||||||
 | 
					            on which the rays are sampled.
 | 
				
			||||||
 | 
					        radii: A tensor of shape `(..., 1)`
 | 
				
			||||||
 | 
					            base radii of the conical frustums.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        rays_dir_mean: A tensor of shape `(..., num_intervals)` representing
 | 
				
			||||||
 | 
					            the mean along the rays direction
 | 
				
			||||||
 | 
					            (t in the parametric represention of the cone)
 | 
				
			||||||
 | 
					        rays_dir_variance: A tensor of shape `(...,  num_intervals)` representing
 | 
				
			||||||
 | 
					            the variance of the conical frustum along the rays
 | 
				
			||||||
 | 
					            (t in the parametric represention of the cone).
 | 
				
			||||||
 | 
					        radii_variance: A tensor of shape `(..., num_intervals)` representing
 | 
				
			||||||
 | 
					            the variance of the conical frustum with respect to its radius.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    t_mu = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
 | 
				
			||||||
 | 
					    t_delta = torch.diff(bins, dim=-1) / 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    t_mu_pow2 = torch.pow(t_mu, 2)
 | 
				
			||||||
 | 
					    t_delta_pow2 = torch.pow(t_delta, 2)
 | 
				
			||||||
 | 
					    t_delta_pow4 = torch.pow(t_delta, 4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    den = 3 * t_mu_pow2 + t_delta_pow2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # mean along the rays direction
 | 
				
			||||||
 | 
					    rays_dir_mean = t_mu + 2 * t_mu * t_delta_pow2 / den
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Variance of the conical frustum with along the rays directions
 | 
				
			||||||
 | 
					    rays_dir_variance = t_delta_pow2 / 3 - (4 / 15) * (
 | 
				
			||||||
 | 
					        t_delta_pow4 * (12 * t_mu_pow2 - t_delta_pow2) / torch.pow(den, 2)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Variance of the conical frustum with respect to its radius
 | 
				
			||||||
 | 
					    radii_variance = torch.pow(radii, 2) * (
 | 
				
			||||||
 | 
					        t_mu_pow2 / 4 + (5 / 12) * t_delta_pow2 - 4 / 15 * (t_delta_pow4) / den
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return rays_dir_mean, rays_dir_variance, radii_variance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def conical_frustum_to_gaussian(
 | 
				
			||||||
 | 
					    ray_bundle: ImplicitronRayBundle,
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Approximate a conical frustum following a ray bundle as a Gaussian.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields:
 | 
				
			||||||
 | 
					            origins: A tensor of shape `(..., 3)`
 | 
				
			||||||
 | 
					            directions: A tensor of shape `(..., 3)`
 | 
				
			||||||
 | 
					            lengths: A tensor of shape `(..., num_points_per_ray)`
 | 
				
			||||||
 | 
					            bins: A tensor of shape `(..., num_points_per_ray + 1)`
 | 
				
			||||||
 | 
					                containing the bins at which the rays are sampled. .
 | 
				
			||||||
 | 
					            pixel_radii_2d: A tensor of shape `(..., 1)`
 | 
				
			||||||
 | 
					                base radii of the conical frustums.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        means: A tensor of shape `(..., num_points_per_ray - 1, 3)`
 | 
				
			||||||
 | 
					            representing the means of the Gaussians
 | 
				
			||||||
 | 
					            approximating the conical frustums.
 | 
				
			||||||
 | 
					        diag_covariances: A tensor of shape `(...,num_points_per_ray -1, 3)`
 | 
				
			||||||
 | 
					            representing the diagonal covariance matrices of our Gaussians.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if ray_bundle.pixel_radii_2d is None or ray_bundle.bins is None:
 | 
				
			||||||
 | 
					        raise ValueError(
 | 
				
			||||||
 | 
					            "RayBundle pixel_radii_2d or bins have not been provided."
 | 
				
			||||||
 | 
					            " Look at pytorch3d.renderer.implicit.renderer.ray_sampler::"
 | 
				
			||||||
 | 
					            "AbstractMaskRaySampler to see how to compute them. Have you forgot to set"
 | 
				
			||||||
 | 
					            "`cast_ray_bundle_as_cone` to True?"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    (
 | 
				
			||||||
 | 
					        rays_dir_mean,
 | 
				
			||||||
 | 
					        rays_dir_variance,
 | 
				
			||||||
 | 
					        radii_variance,
 | 
				
			||||||
 | 
					    ) = approximate_conical_frustum_as_gaussians(
 | 
				
			||||||
 | 
					        ray_bundle.bins,
 | 
				
			||||||
 | 
					        ray_bundle.pixel_radii_2d,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    means = ray_bundle_variables_to_ray_points(
 | 
				
			||||||
 | 
					        ray_bundle.origins, ray_bundle.directions, rays_dir_mean
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    diag_covariances = compute_3d_diagonal_covariance_gaussian(
 | 
				
			||||||
 | 
					        ray_bundle.directions, rays_dir_variance, radii_variance
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return means, diag_covariances
 | 
				
			||||||
 | 
				
			|||||||
@ -11,6 +11,7 @@ from pytorch3d.implicitron.tools import camera_utils
 | 
				
			|||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
					from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
				
			||||||
from pytorch3d.renderer import NDCMultinomialRaysampler
 | 
					from pytorch3d.renderer import NDCMultinomialRaysampler
 | 
				
			||||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
					from pytorch3d.renderer.cameras import CamerasBase
 | 
				
			||||||
 | 
					from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode
 | 
					from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -83,7 +84,20 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
				
			|||||||
        stratified_point_sampling_training: if set, performs stratified random sampling
 | 
					        stratified_point_sampling_training: if set, performs stratified random sampling
 | 
				
			||||||
            along the ray; otherwise takes ray points at deterministic offsets.
 | 
					            along the ray; otherwise takes ray points at deterministic offsets.
 | 
				
			||||||
        stratified_point_sampling_evaluation: Same as above but for evaluation.
 | 
					        stratified_point_sampling_evaluation: Same as above but for evaluation.
 | 
				
			||||||
 | 
					        cast_ray_bundle_as_cone: If True, the sampling will generate the bins and radii
 | 
				
			||||||
 | 
					            attribute of ImplicitronRayBundle. The `bins` contain the z-coordinate
 | 
				
			||||||
 | 
					            (=depth) of each ray in world units and are of shape
 | 
				
			||||||
 | 
					            `(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation + 1)`
 | 
				
			||||||
 | 
					            while `lengths` is equal to the midpoint of the bins:
 | 
				
			||||||
 | 
					            (0.5 * (bins[..., 1:] + bins[..., :-1]).
 | 
				
			||||||
 | 
					            If False, `bins` is None, `radii` is None and `lengths` contains
 | 
				
			||||||
 | 
					            the z-coordinate (=depth) of each ray in world units and are of shape
 | 
				
			||||||
 | 
					            `(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation)`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					        TypeError: if cast_ray_bundle_as_cone is set to True and n_rays_total_training
 | 
				
			||||||
 | 
					            is not None will result in an error. HeterogeneousRayBundle is
 | 
				
			||||||
 | 
					            not supported for conical frustum computation yet.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    image_width: int = 400
 | 
					    image_width: int = 400
 | 
				
			||||||
@ -97,6 +111,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
				
			|||||||
    # stratified sampling vs taking points at deterministic offsets
 | 
					    # stratified sampling vs taking points at deterministic offsets
 | 
				
			||||||
    stratified_point_sampling_training: bool = True
 | 
					    stratified_point_sampling_training: bool = True
 | 
				
			||||||
    stratified_point_sampling_evaluation: bool = False
 | 
					    stratified_point_sampling_evaluation: bool = False
 | 
				
			||||||
 | 
					    cast_ray_bundle_as_cone: bool = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __post_init__(self):
 | 
					    def __post_init__(self):
 | 
				
			||||||
        if (self.n_rays_per_image_sampled_from_mask is not None) and (
 | 
					        if (self.n_rays_per_image_sampled_from_mask is not None) and (
 | 
				
			||||||
@ -114,10 +129,20 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
				
			|||||||
            ),
 | 
					            ),
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        n_pts_per_ray_training = (
 | 
				
			||||||
 | 
					            self.n_pts_per_ray_training + 1
 | 
				
			||||||
 | 
					            if self.cast_ray_bundle_as_cone
 | 
				
			||||||
 | 
					            else self.n_pts_per_ray_training
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        n_pts_per_ray_evaluation = (
 | 
				
			||||||
 | 
					            self.n_pts_per_ray_evaluation + 1
 | 
				
			||||||
 | 
					            if self.cast_ray_bundle_as_cone
 | 
				
			||||||
 | 
					            else self.n_pts_per_ray_evaluation
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        self._training_raysampler = NDCMultinomialRaysampler(
 | 
					        self._training_raysampler = NDCMultinomialRaysampler(
 | 
				
			||||||
            image_width=self.image_width,
 | 
					            image_width=self.image_width,
 | 
				
			||||||
            image_height=self.image_height,
 | 
					            image_height=self.image_height,
 | 
				
			||||||
            n_pts_per_ray=self.n_pts_per_ray_training,
 | 
					            n_pts_per_ray=n_pts_per_ray_training,
 | 
				
			||||||
            min_depth=0.0,
 | 
					            min_depth=0.0,
 | 
				
			||||||
            max_depth=0.0,
 | 
					            max_depth=0.0,
 | 
				
			||||||
            n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
					            n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
				
			||||||
@ -132,7 +157,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
				
			|||||||
        self._evaluation_raysampler = NDCMultinomialRaysampler(
 | 
					        self._evaluation_raysampler = NDCMultinomialRaysampler(
 | 
				
			||||||
            image_width=self.image_width,
 | 
					            image_width=self.image_width,
 | 
				
			||||||
            image_height=self.image_height,
 | 
					            image_height=self.image_height,
 | 
				
			||||||
            n_pts_per_ray=self.n_pts_per_ray_evaluation,
 | 
					            n_pts_per_ray=n_pts_per_ray_evaluation,
 | 
				
			||||||
            min_depth=0.0,
 | 
					            min_depth=0.0,
 | 
				
			||||||
            max_depth=0.0,
 | 
					            max_depth=0.0,
 | 
				
			||||||
            n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
					            n_rays_per_image=self.n_rays_per_image_sampled_from_mask
 | 
				
			||||||
@ -143,6 +168,11 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
				
			|||||||
            stratified_sampling=self.stratified_point_sampling_evaluation,
 | 
					            stratified_sampling=self.stratified_point_sampling_evaluation,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        max_y, min_y = self._training_raysampler.max_y, self._training_raysampler.min_y
 | 
				
			||||||
 | 
					        max_x, min_x = self._training_raysampler.max_x, self._training_raysampler.min_x
 | 
				
			||||||
 | 
					        self.pixel_height: float = (max_y - min_y) / (self.image_height - 1)
 | 
				
			||||||
 | 
					        self.pixel_width: float = (max_x - min_x) / (self.image_width - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
 | 
					    def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -193,19 +223,34 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
				
			|||||||
            min_depth=min_depth,
 | 
					            min_depth=min_depth,
 | 
				
			||||||
            max_depth=max_depth,
 | 
					            max_depth=max_depth,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        if self.cast_ray_bundle_as_cone and isinstance(
 | 
				
			||||||
        if isinstance(ray_bundle, tuple):
 | 
					            ray_bundle, HeterogeneousRayBundle
 | 
				
			||||||
            return ImplicitronRayBundle(
 | 
					        ):
 | 
				
			||||||
                # pyre-ignore[16]
 | 
					            # If this error rises it means that raysampler has among
 | 
				
			||||||
                **{k: v for k, v in ray_bundle._asdict().items()}
 | 
					            # its arguments `n_ray_totals`. If it is the case
 | 
				
			||||||
 | 
					            # then you should update the radii computation and lengths
 | 
				
			||||||
 | 
					            # computation to handle padding and unpadding.
 | 
				
			||||||
 | 
					            raise TypeError(
 | 
				
			||||||
 | 
					                "Heterogeneous ray bundle is not supported for conical frustum computation yet"
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					        elif self.cast_ray_bundle_as_cone:
 | 
				
			||||||
 | 
					            pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
 | 
				
			||||||
 | 
					            pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
 | 
				
			||||||
 | 
					            return ImplicitronRayBundle.from_bins(
 | 
				
			||||||
 | 
					                directions=ray_bundle.directions,
 | 
				
			||||||
 | 
					                origins=ray_bundle.origins,
 | 
				
			||||||
 | 
					                bins=ray_bundle.lengths,
 | 
				
			||||||
 | 
					                xys=ray_bundle.xys,
 | 
				
			||||||
 | 
					                pixel_radii_2d=pixel_radii_2d,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return ImplicitronRayBundle(
 | 
					        return ImplicitronRayBundle(
 | 
				
			||||||
            directions=ray_bundle.directions,
 | 
					            directions=ray_bundle.directions,
 | 
				
			||||||
            origins=ray_bundle.origins,
 | 
					            origins=ray_bundle.origins,
 | 
				
			||||||
            lengths=ray_bundle.lengths,
 | 
					            lengths=ray_bundle.lengths,
 | 
				
			||||||
            xys=ray_bundle.xys,
 | 
					            xys=ray_bundle.xys,
 | 
				
			||||||
            camera_ids=ray_bundle.camera_ids,
 | 
					            camera_counts=getattr(ray_bundle, "camera_counts", None),
 | 
				
			||||||
            camera_counts=ray_bundle.camera_counts,
 | 
					            camera_ids=getattr(ray_bundle, "camera_ids", None),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -274,3 +319,62 @@ class NearFarRaySampler(AbstractMaskRaySampler):
 | 
				
			|||||||
        Returns the stored near/far planes.
 | 
					        Returns the stored near/far planes.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        return self.min_depth, self.max_depth
 | 
					        return self.min_depth, self.max_depth
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def compute_radii(
 | 
				
			||||||
 | 
					    cameras: CamerasBase,
 | 
				
			||||||
 | 
					    xy_grid: torch.Tensor,
 | 
				
			||||||
 | 
					    pixel_hw_ndc: Tuple[float, float],
 | 
				
			||||||
 | 
					) -> torch.Tensor:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Compute radii of conical frustums in world coordinates.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        cameras: cameras object representing a batch of cameras.
 | 
				
			||||||
 | 
					        xy_grid: torch.tensor grid of image xy coords.
 | 
				
			||||||
 | 
					        pixel_hw_ndc: pixel height and width in NDC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        radii: A tensor of shape `(..., 1)` radii of a cone.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    batch_size = xy_grid.shape[0]
 | 
				
			||||||
 | 
					    spatial_size = xy_grid.shape[1:-1]
 | 
				
			||||||
 | 
					    n_rays_per_image = spatial_size.numel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    xy = xy_grid.view(batch_size, n_rays_per_image, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # [batch_size, 3 * n_rays_per_image, 2]
 | 
				
			||||||
 | 
					    xy = torch.cat(
 | 
				
			||||||
 | 
					        [
 | 
				
			||||||
 | 
					            xy,
 | 
				
			||||||
 | 
					            # Will allow to find the norm on the x axis
 | 
				
			||||||
 | 
					            xy + torch.tensor([pixel_hw_ndc[1], 0], device=xy.device),
 | 
				
			||||||
 | 
					            # Will allow to find the norm on the y axis
 | 
				
			||||||
 | 
					            xy + torch.tensor([0, pixel_hw_ndc[0]], device=xy.device),
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        dim=1,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    # [batch_size, 3 * n_rays_per_image, 3]
 | 
				
			||||||
 | 
					    xyz = torch.cat(
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            xy,
 | 
				
			||||||
 | 
					            xy.new_ones(batch_size, 3 * n_rays_per_image, 1),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        dim=-1,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # unproject the points
 | 
				
			||||||
 | 
					    unprojected_xyz = cameras.unproject_points(xyz, from_ndc=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    plane_world, plane_world_dx, plane_world_dy = torch.split(
 | 
				
			||||||
 | 
					        unprojected_xyz, n_rays_per_image, dim=1
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Distance from each unit-norm direction vector to its neighbors.
 | 
				
			||||||
 | 
					    dx_norm = torch.linalg.norm(plane_world_dx - plane_world, dim=-1, keepdims=True)
 | 
				
			||||||
 | 
					    dy_norm = torch.linalg.norm(plane_world_dy - plane_world, dim=-1, keepdims=True)
 | 
				
			||||||
 | 
					    # Cut the distance in half to obtain the base radius: (dx_norm + dy_norm) * 0.5
 | 
				
			||||||
 | 
					    # Scale it by 2/12**0.5 to match the variance of the pixel’s footprint
 | 
				
			||||||
 | 
					    radii = (dx_norm + dy_norm) / 12**0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return radii.view(batch_size, *spatial_size, 1)
 | 
				
			||||||
 | 
				
			|||||||
@ -177,6 +177,20 @@ def chunk_generator(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    for start_idx in iter:
 | 
					    for start_idx in iter:
 | 
				
			||||||
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
 | 
					        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
 | 
				
			||||||
 | 
					        bins = (
 | 
				
			||||||
 | 
					            None
 | 
				
			||||||
 | 
					            if ray_bundle.bins is None
 | 
				
			||||||
 | 
					            else ray_bundle.bins.reshape(batch_size, n_rays, n_pts_per_ray + 1)[
 | 
				
			||||||
 | 
					                :, start_idx:end_idx
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        pixel_radii_2d = (
 | 
				
			||||||
 | 
					            None
 | 
				
			||||||
 | 
					            if ray_bundle.pixel_radii_2d is None
 | 
				
			||||||
 | 
					            else ray_bundle.pixel_radii_2d.reshape(batch_size, -1, 1)[
 | 
				
			||||||
 | 
					                :, start_idx:end_idx
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        ray_bundle_chunk = ImplicitronRayBundle(
 | 
					        ray_bundle_chunk = ImplicitronRayBundle(
 | 
				
			||||||
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
 | 
					            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
 | 
				
			||||||
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
 | 
					            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
 | 
				
			||||||
@ -186,6 +200,8 @@ def chunk_generator(
 | 
				
			|||||||
                :, start_idx:end_idx
 | 
					                :, start_idx:end_idx
 | 
				
			||||||
            ],
 | 
					            ],
 | 
				
			||||||
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
 | 
					            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
 | 
				
			||||||
 | 
					            bins=bins,
 | 
				
			||||||
 | 
					            pixel_radii_2d=pixel_radii_2d,
 | 
				
			||||||
            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
 | 
					            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
 | 
				
			||||||
            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
 | 
					            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
@ -58,6 +58,12 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
				
			|||||||
    coordinate convention. For a raysampler which follows the PyTorch3D
 | 
					    coordinate convention. For a raysampler which follows the PyTorch3D
 | 
				
			||||||
    coordinate conventions please refer to `NDCMultinomialRaysampler`.
 | 
					    coordinate conventions please refer to `NDCMultinomialRaysampler`.
 | 
				
			||||||
    As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`.
 | 
					    As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Attributes:
 | 
				
			||||||
 | 
					        min_x: The leftmost x-coordinate of each ray's source pixel's center.
 | 
				
			||||||
 | 
					        max_x: The rightmost x-coordinate of each ray's source pixel's center.
 | 
				
			||||||
 | 
					        min_y: The topmost y-coordinate of each ray's source pixel's center.
 | 
				
			||||||
 | 
					        max_y: The bottommost y-coordinate of each ray's source pixel's center.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
@ -107,7 +113,8 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
				
			|||||||
        self._n_rays_total = n_rays_total
 | 
					        self._n_rays_total = n_rays_total
 | 
				
			||||||
        self._unit_directions = unit_directions
 | 
					        self._unit_directions = unit_directions
 | 
				
			||||||
        self._stratified_sampling = stratified_sampling
 | 
					        self._stratified_sampling = stratified_sampling
 | 
				
			||||||
 | 
					        self.min_x, self.max_x = min_x, max_x
 | 
				
			||||||
 | 
					        self.min_y, self.max_y = min_y, max_y
 | 
				
			||||||
        # get the initial grid of image xy coords
 | 
					        # get the initial grid of image xy coords
 | 
				
			||||||
        y, x = meshgrid_ij(
 | 
					        y, x = meshgrid_ij(
 | 
				
			||||||
            torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
 | 
					            torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										77
									
								
								tests/common_camera_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								tests/common_camera_utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,77 @@
 | 
				
			|||||||
 | 
					# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
				
			||||||
 | 
					# All rights reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This source code is licensed under the BSD-style license found in the
 | 
				
			||||||
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import typing
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from pytorch3d.common.datatypes import Device
 | 
				
			||||||
 | 
					from pytorch3d.renderer.cameras import (
 | 
				
			||||||
 | 
					    CamerasBase,
 | 
				
			||||||
 | 
					    FoVOrthographicCameras,
 | 
				
			||||||
 | 
					    FoVPerspectiveCameras,
 | 
				
			||||||
 | 
					    OpenGLOrthographicCameras,
 | 
				
			||||||
 | 
					    OpenGLPerspectiveCameras,
 | 
				
			||||||
 | 
					    OrthographicCameras,
 | 
				
			||||||
 | 
					    PerspectiveCameras,
 | 
				
			||||||
 | 
					    SfMOrthographicCameras,
 | 
				
			||||||
 | 
					    SfMPerspectiveCameras,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from pytorch3d.renderer.fisheyecameras import FishEyeCameras
 | 
				
			||||||
 | 
					from pytorch3d.transforms.so3 import so3_exp_map
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_random_cameras(
 | 
				
			||||||
 | 
					    cam_type: typing.Type[CamerasBase],
 | 
				
			||||||
 | 
					    batch_size: int,
 | 
				
			||||||
 | 
					    random_z: bool = False,
 | 
				
			||||||
 | 
					    device: Device = "cpu",
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    cam_params = {}
 | 
				
			||||||
 | 
					    T = torch.randn(batch_size, 3) * 0.03
 | 
				
			||||||
 | 
					    if not random_z:
 | 
				
			||||||
 | 
					        T[:, 2] = 4
 | 
				
			||||||
 | 
					    R = so3_exp_map(torch.randn(batch_size, 3) * 3.0)
 | 
				
			||||||
 | 
					    cam_params = {"R": R, "T": T, "device": device}
 | 
				
			||||||
 | 
					    if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
 | 
				
			||||||
 | 
					        cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
 | 
				
			||||||
 | 
					        cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
 | 
				
			||||||
 | 
					        if cam_type == OpenGLPerspectiveCameras:
 | 
				
			||||||
 | 
					            cam_params["fov"] = torch.rand(batch_size) * 60 + 30
 | 
				
			||||||
 | 
					            cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9
 | 
				
			||||||
 | 
					            cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
 | 
				
			||||||
 | 
					            cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
 | 
				
			||||||
 | 
					            cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
 | 
				
			||||||
 | 
					    elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras):
 | 
				
			||||||
 | 
					        cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
 | 
				
			||||||
 | 
					        cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
 | 
				
			||||||
 | 
					        if cam_type == FoVPerspectiveCameras:
 | 
				
			||||||
 | 
					            cam_params["fov"] = torch.rand(batch_size) * 60 + 30
 | 
				
			||||||
 | 
					            cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9
 | 
				
			||||||
 | 
					            cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9
 | 
				
			||||||
 | 
					            cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9
 | 
				
			||||||
 | 
					            cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9
 | 
				
			||||||
 | 
					    elif cam_type in (
 | 
				
			||||||
 | 
					        SfMOrthographicCameras,
 | 
				
			||||||
 | 
					        SfMPerspectiveCameras,
 | 
				
			||||||
 | 
					        OrthographicCameras,
 | 
				
			||||||
 | 
					        PerspectiveCameras,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
 | 
				
			||||||
 | 
					        cam_params["principal_point"] = torch.randn((batch_size, 2))
 | 
				
			||||||
 | 
					    elif cam_type == FishEyeCameras:
 | 
				
			||||||
 | 
					        cam_params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1
 | 
				
			||||||
 | 
					        cam_params["principal_point"] = torch.randn((batch_size, 2))
 | 
				
			||||||
 | 
					        cam_params["radial_params"] = torch.randn((batch_size, 6))
 | 
				
			||||||
 | 
					        cam_params["tangential_params"] = torch.randn((batch_size, 2))
 | 
				
			||||||
 | 
					        cam_params["thin_prism_params"] = torch.randn((batch_size, 4))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise ValueError(str(cam_type))
 | 
				
			||||||
 | 
					    return cam_type(**cam_params)
 | 
				
			||||||
@ -62,6 +62,7 @@ raysampler_AdaptiveRaySampler_args:
 | 
				
			|||||||
  n_rays_total_training: null
 | 
					  n_rays_total_training: null
 | 
				
			||||||
  stratified_point_sampling_training: true
 | 
					  stratified_point_sampling_training: true
 | 
				
			||||||
  stratified_point_sampling_evaluation: false
 | 
					  stratified_point_sampling_evaluation: false
 | 
				
			||||||
 | 
					  cast_ray_bundle_as_cone: false
 | 
				
			||||||
  scene_extent: 8.0
 | 
					  scene_extent: 8.0
 | 
				
			||||||
  scene_center:
 | 
					  scene_center:
 | 
				
			||||||
  - 0.0
 | 
					  - 0.0
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										254
									
								
								tests/implicitron/test_models_renderer_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										254
									
								
								tests/implicitron/test_models_renderer_base.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,254 @@
 | 
				
			|||||||
 | 
					# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
				
			||||||
 | 
					# All rights reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This source code is licensed under the BSD-style license found in the
 | 
				
			||||||
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pytorch3d.implicitron.models.renderer.base import (
 | 
				
			||||||
 | 
					    approximate_conical_frustum_as_gaussians,
 | 
				
			||||||
 | 
					    compute_3d_diagonal_covariance_gaussian,
 | 
				
			||||||
 | 
					    conical_frustum_to_gaussian,
 | 
				
			||||||
 | 
					    ImplicitronRayBundle,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from pytorch3d.implicitron.models.renderer.ray_sampler import AbstractMaskRaySampler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from tests.common_testing import TestCaseMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestRendererBase(TestCaseMixin, unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_implicitron_from_bins(self) -> None:
 | 
				
			||||||
 | 
					        bins = torch.randn(2, 3, 4, 5)
 | 
				
			||||||
 | 
					        ray_bundle = ImplicitronRayBundle.from_bins(
 | 
				
			||||||
 | 
					            origins=None,
 | 
				
			||||||
 | 
					            directions=None,
 | 
				
			||||||
 | 
					            xys=None,
 | 
				
			||||||
 | 
					            bins=bins,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
 | 
				
			||||||
 | 
					        self.assertClose(ray_bundle.bins, bins)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
 | 
				
			||||||
 | 
					        with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					            ImplicitronRayBundle.from_bins(
 | 
				
			||||||
 | 
					                origins=torch.rand(2, 3, 4, 3),
 | 
				
			||||||
 | 
					                directions=torch.rand(2, 3, 4, 3),
 | 
				
			||||||
 | 
					                xys=torch.rand(2, 3, 4, 2),
 | 
				
			||||||
 | 
					                bins=torch.rand(2, 3, 4, 1),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_conical_frustum_to_gaussian(self) -> None:
 | 
				
			||||||
 | 
					        origins = torch.zeros(3, 3, 3)
 | 
				
			||||||
 | 
					        directions = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [[0, 0, 0], [1, 0, 0], [3, 0, 0]],
 | 
				
			||||||
 | 
					                [[0, 0.25, 0], [1, 0.25, 0], [3, 0.25, 0]],
 | 
				
			||||||
 | 
					                [[0, 1, 0], [1, 1, 0], [3, 1, 0]],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        bins = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]],
 | 
				
			||||||
 | 
					                [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]],
 | 
				
			||||||
 | 
					                [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # see test_compute_pixel_radii_from_ray_direction
 | 
				
			||||||
 | 
					        radii = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [1.25, 2.25, 2.25],
 | 
				
			||||||
 | 
					                [1.75, 2.75, 2.75],
 | 
				
			||||||
 | 
					                [1.75, 2.75, 2.75],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        radii = radii[..., None] / 12**0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # The expected mean and diagonal covariance have been computed
 | 
				
			||||||
 | 
					        # by hand from the official code of MipNerf.
 | 
				
			||||||
 | 
					        # https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip.py#L125
 | 
				
			||||||
 | 
					        # mean, cov_diag = cast_rays(length, origins, directions, radii, 'cone', diag=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        expected_mean = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    [[0.0, 0.0, 0.0]],
 | 
				
			||||||
 | 
					                    [[0.5506329, 0.0, 0.0]],
 | 
				
			||||||
 | 
					                    [[1.6518986, 0.0, 0.0]],
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    [[0.0, 0.28846154, 0.0]],
 | 
				
			||||||
 | 
					                    [[0.5506329, 0.13765822, 0.0]],
 | 
				
			||||||
 | 
					                    [[1.6518986, 0.13765822, 0.0]],
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    [[0.0, 1.1538461, 0.0]],
 | 
				
			||||||
 | 
					                    [[0.5506329, 0.5506329, 0.0]],
 | 
				
			||||||
 | 
					                    [[1.6518986, 0.5506329, 0.0]],
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        expected_diag_cov = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    [[0.04544772, 0.04544772, 0.04544772]],
 | 
				
			||||||
 | 
					                    [[0.01130973, 0.03317059, 0.03317059]],
 | 
				
			||||||
 | 
					                    [[0.10178753, 0.03317059, 0.03317059]],
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    [[0.08907752, 0.00404956, 0.08907752]],
 | 
				
			||||||
 | 
					                    [[0.0142245, 0.04734321, 0.04955113]],
 | 
				
			||||||
 | 
					                    [[0.10212927, 0.04991625, 0.04955113]],
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    [[0.08907752, 0.0647929, 0.08907752]],
 | 
				
			||||||
 | 
					                    [[0.03608529, 0.03608529, 0.04955113]],
 | 
				
			||||||
 | 
					                    [[0.10674264, 0.05590574, 0.04955113]],
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ray = ImplicitronRayBundle(
 | 
				
			||||||
 | 
					            origins=origins,
 | 
				
			||||||
 | 
					            directions=directions,
 | 
				
			||||||
 | 
					            bins=bins,
 | 
				
			||||||
 | 
					            lengths=None,
 | 
				
			||||||
 | 
					            pixel_radii_2d=radii,
 | 
				
			||||||
 | 
					            xys=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        mean, diag_cov = conical_frustum_to_gaussian(ray)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertClose(mean, expected_mean)
 | 
				
			||||||
 | 
					        self.assertClose(diag_cov, expected_diag_cov)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_scale_conical_frustum_to_gaussian(self) -> None:
 | 
				
			||||||
 | 
					        origins = torch.zeros(2, 2, 3)
 | 
				
			||||||
 | 
					        directions = torch.Tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [[0, 1, 0], [0, 0, 1]],
 | 
				
			||||||
 | 
					                [[0, 1, 0], [0, 0, 1]],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        bins = torch.Tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [[0.5, 1.5], [0.3, 0.7]],
 | 
				
			||||||
 | 
					                [[0.5, 1.5], [0.3, 0.7]],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        radii = torch.ones(2, 2, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ray = ImplicitronRayBundle(
 | 
				
			||||||
 | 
					            origins=origins,
 | 
				
			||||||
 | 
					            directions=directions,
 | 
				
			||||||
 | 
					            bins=bins,
 | 
				
			||||||
 | 
					            pixel_radii_2d=radii,
 | 
				
			||||||
 | 
					            lengths=None,
 | 
				
			||||||
 | 
					            xys=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        mean, diag_cov = conical_frustum_to_gaussian(ray)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scaling_factor = 2.5
 | 
				
			||||||
 | 
					        ray = ImplicitronRayBundle(
 | 
				
			||||||
 | 
					            origins=origins,
 | 
				
			||||||
 | 
					            directions=directions,
 | 
				
			||||||
 | 
					            bins=bins * scaling_factor,
 | 
				
			||||||
 | 
					            pixel_radii_2d=radii,
 | 
				
			||||||
 | 
					            lengths=None,
 | 
				
			||||||
 | 
					            xys=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        mean_scaled, diag_cov_scaled = conical_frustum_to_gaussian(ray)
 | 
				
			||||||
 | 
					        np.testing.assert_allclose(mean * scaling_factor, mean_scaled)
 | 
				
			||||||
 | 
					        np.testing.assert_allclose(
 | 
				
			||||||
 | 
					            diag_cov * scaling_factor**2, diag_cov_scaled, atol=1e-6
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_approximate_conical_frustum_as_gaussian(self) -> None:
 | 
				
			||||||
 | 
					        """Ensure that the computation modularity in our function is well done."""
 | 
				
			||||||
 | 
					        bins = torch.Tensor([[0.5, 1.5], [0.3, 0.7]])
 | 
				
			||||||
 | 
					        radii = torch.Tensor([[1.0], [1.0]])
 | 
				
			||||||
 | 
					        t_mean, t_var, r_var = approximate_conical_frustum_as_gaussians(bins, radii)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(t_mean.shape, (2, 1))
 | 
				
			||||||
 | 
					        self.assertEqual(t_var.shape, (2, 1))
 | 
				
			||||||
 | 
					        self.assertEqual(r_var.shape, (2, 1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        mu = np.array([[1.0], [0.5]])
 | 
				
			||||||
 | 
					        delta = np.array([[0.5], [0.2]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        np.testing.assert_allclose(
 | 
				
			||||||
 | 
					            mu + (2 * mu * delta**2) / (3 * mu**2 + delta**2), t_mean.numpy()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        np.testing.assert_allclose(
 | 
				
			||||||
 | 
					            (delta**2) / 3
 | 
				
			||||||
 | 
					            - (4 / 15)
 | 
				
			||||||
 | 
					            * (
 | 
				
			||||||
 | 
					                (delta**4 * (12 * mu**2 - delta**2))
 | 
				
			||||||
 | 
					                / (3 * mu**2 + delta**2) ** 2
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            t_var.numpy(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        np.testing.assert_allclose(
 | 
				
			||||||
 | 
					            radii**2
 | 
				
			||||||
 | 
					            * (
 | 
				
			||||||
 | 
					                (mu**2) / 4
 | 
				
			||||||
 | 
					                + (5 / 12) * delta**2
 | 
				
			||||||
 | 
					                - 4 / 15 * (delta**4) / (3 * mu**2 + delta**2)
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            r_var.numpy(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_compute_3d_diagonal_covariance_gaussian(self) -> None:
 | 
				
			||||||
 | 
					        ray_directions = torch.Tensor([[0, 0, 1]])
 | 
				
			||||||
 | 
					        t_var = torch.Tensor([0.5, 0.5, 1])
 | 
				
			||||||
 | 
					        r_var = torch.Tensor([0.6, 0.3, 0.4])
 | 
				
			||||||
 | 
					        expected_diag_cov = np.array(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    # t_cov_diag + xy_cov_diag
 | 
				
			||||||
 | 
					                    [0.0 + 0.6, 0.0 + 0.6, 0.5 + 0.0],
 | 
				
			||||||
 | 
					                    [0.0 + 0.3, 0.0 + 0.3, 0.5 + 0.0],
 | 
				
			||||||
 | 
					                    [0.0 + 0.4, 0.0 + 0.4, 1.0 + 0.0],
 | 
				
			||||||
 | 
					                ]
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        diag_cov = compute_3d_diagonal_covariance_gaussian(ray_directions, t_var, r_var)
 | 
				
			||||||
 | 
					        np.testing.assert_allclose(diag_cov.numpy(), expected_diag_cov)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_conical_frustum_to_gaussian_raise_valueerror(self) -> None:
 | 
				
			||||||
 | 
					        lengths = torch.linspace(0, 1, steps=6)
 | 
				
			||||||
 | 
					        directions = torch.tensor([0, 0, 1])
 | 
				
			||||||
 | 
					        origins = torch.tensor([1, 1, 1])
 | 
				
			||||||
 | 
					        ray = ImplicitronRayBundle(
 | 
				
			||||||
 | 
					            origins=origins, directions=directions, lengths=lengths, xys=None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        with self.assertRaises(ValueError) as context:
 | 
				
			||||||
 | 
					            _ = conical_frustum_to_gaussian(ray)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        expected_error_message = (
 | 
				
			||||||
 | 
					            "RayBundle pixel_radii_2d or bins have not been provided."
 | 
				
			||||||
 | 
					            " Look at pytorch3d.renderer.implicit.renderer.ray_sampler::"
 | 
				
			||||||
 | 
					            "AbstractMaskRaySampler to see how to compute them. Have you forgot to set"
 | 
				
			||||||
 | 
					            "`cast_ray_bundle_as_cone` to True?"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(expected_error_message, str(context.exception))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Ensure message is coherent with AbstractMaskRaySampler
 | 
				
			||||||
 | 
					        class FakeRaySampler(AbstractMaskRaySampler):
 | 
				
			||||||
 | 
					            def _get_min_max_depth_bounds(self, *args):
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        message_assertion = (
 | 
				
			||||||
 | 
					            "If cast_ray_bundle_as_cone has been removed please update the doc"
 | 
				
			||||||
 | 
					            "conical_frustum_to_gaussian"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertIsNotNone(
 | 
				
			||||||
 | 
					            getattr(FakeRaySampler(), "cast_ray_bundle_as_cone", None),
 | 
				
			||||||
 | 
					            message_assertion,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
							
								
								
									
										290
									
								
								tests/implicitron/test_models_renderer_ray_sampler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										290
									
								
								tests/implicitron/test_models_renderer_ray_sampler.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,290 @@
 | 
				
			|||||||
 | 
					# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
				
			||||||
 | 
					# All rights reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This source code is licensed under the BSD-style license found in the
 | 
				
			||||||
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					from itertools import product
 | 
				
			||||||
 | 
					from typing import Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from unittest.mock import patch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from pytorch3d.common.compat import meshgrid_ij
 | 
				
			||||||
 | 
					from pytorch3d.implicitron.models.renderer.base import EvaluationMode
 | 
				
			||||||
 | 
					from pytorch3d.implicitron.models.renderer.ray_sampler import (
 | 
				
			||||||
 | 
					    AdaptiveRaySampler,
 | 
				
			||||||
 | 
					    compute_radii,
 | 
				
			||||||
 | 
					    NearFarRaySampler,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pytorch3d.renderer.cameras import (
 | 
				
			||||||
 | 
					    CamerasBase,
 | 
				
			||||||
 | 
					    FoVOrthographicCameras,
 | 
				
			||||||
 | 
					    FoVPerspectiveCameras,
 | 
				
			||||||
 | 
					    OrthographicCameras,
 | 
				
			||||||
 | 
					    PerspectiveCameras,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle
 | 
				
			||||||
 | 
					from tests.common_camera_utils import init_random_cameras
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from tests.common_testing import TestCaseMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CAMERA_TYPES = (
 | 
				
			||||||
 | 
					    FoVPerspectiveCameras,
 | 
				
			||||||
 | 
					    FoVOrthographicCameras,
 | 
				
			||||||
 | 
					    OrthographicCameras,
 | 
				
			||||||
 | 
					    PerspectiveCameras,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def unproject_xy_grid_from_ndc_to_world_coord(
 | 
				
			||||||
 | 
					    cameras: CamerasBase, xy_grid: torch.Tensor
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Unproject a xy_grid from NDC coordinates to world coordinates.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        cameras: CamerasBase.
 | 
				
			||||||
 | 
					        xy_grid: A tensor of shape `(..., H*W, 2)` representing the
 | 
				
			||||||
 | 
					            x, y coords.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        A tensor of shape `(..., H*W, 3)` representing the
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    batch_size = xy_grid.shape[0]
 | 
				
			||||||
 | 
					    n_rays_per_image = xy_grid.shape[1:-1].numel()
 | 
				
			||||||
 | 
					    xy = xy_grid.view(batch_size, -1, 2)
 | 
				
			||||||
 | 
					    xyz = torch.cat([xy, xy_grid.new_ones(batch_size, n_rays_per_image, 1)], dim=-1)
 | 
				
			||||||
 | 
					    plane_at_depth1 = cameras.unproject_points(xyz, from_ndc=True)
 | 
				
			||||||
 | 
					    return plane_at_depth1.view(*xy_grid.shape[:-1], 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestRaysampler(TestCaseMixin, unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_ndc_raysampler_n_ray_total_is_none(self):
 | 
				
			||||||
 | 
					        sampler = NearFarRaySampler()
 | 
				
			||||||
 | 
					        message = (
 | 
				
			||||||
 | 
					            "If you introduce the support of `n_rays_total` for {0}, please handle the "
 | 
				
			||||||
 | 
					            "packing and unpacking logic for the radii and lengths computation."
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertIsNone(
 | 
				
			||||||
 | 
					            sampler._training_raysampler._n_rays_total, message.format(type(sampler))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertIsNone(
 | 
				
			||||||
 | 
					            sampler._evaluation_raysampler._n_rays_total, message.format(type(sampler))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sampler = AdaptiveRaySampler()
 | 
				
			||||||
 | 
					        self.assertIsNone(
 | 
				
			||||||
 | 
					            sampler._training_raysampler._n_rays_total, message.format(type(sampler))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertIsNone(
 | 
				
			||||||
 | 
					            sampler._evaluation_raysampler._n_rays_total, message.format(type(sampler))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_catch_heterogeneous_exception(self):
 | 
				
			||||||
 | 
					        cameras = init_random_cameras(FoVPerspectiveCameras, 1, random_z=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class FakeSampler:
 | 
				
			||||||
 | 
					            def __init__(self):
 | 
				
			||||||
 | 
					                self.min_x, self.max_x = 1, 2
 | 
				
			||||||
 | 
					                self.min_y, self.max_y = 1, 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def __call__(self, **kwargs):
 | 
				
			||||||
 | 
					                return HeterogeneousRayBundle(
 | 
				
			||||||
 | 
					                    torch.rand(3), torch.rand(3), torch.rand(3), torch.rand(1)
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with patch(
 | 
				
			||||||
 | 
					            "pytorch3d.implicitron.models.renderer.ray_sampler.NDCMultinomialRaysampler",
 | 
				
			||||||
 | 
					            return_value=FakeSampler(),
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            for sampler in [
 | 
				
			||||||
 | 
					                AdaptiveRaySampler(cast_ray_bundle_as_cone=True),
 | 
				
			||||||
 | 
					                NearFarRaySampler(cast_ray_bundle_as_cone=True),
 | 
				
			||||||
 | 
					            ]:
 | 
				
			||||||
 | 
					                with self.assertRaises(TypeError):
 | 
				
			||||||
 | 
					                    _ = sampler(cameras, EvaluationMode.TRAINING)
 | 
				
			||||||
 | 
					            for sampler in [
 | 
				
			||||||
 | 
					                AdaptiveRaySampler(cast_ray_bundle_as_cone=False),
 | 
				
			||||||
 | 
					                NearFarRaySampler(cast_ray_bundle_as_cone=False),
 | 
				
			||||||
 | 
					            ]:
 | 
				
			||||||
 | 
					                _ = sampler(cameras, EvaluationMode.TRAINING)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_compute_radii(self):
 | 
				
			||||||
 | 
					        batch_size = 1
 | 
				
			||||||
 | 
					        image_height, image_width = 20, 10
 | 
				
			||||||
 | 
					        min_y, max_y, min_x, max_x = -1.0, 1.0, -1.0, 1.0
 | 
				
			||||||
 | 
					        y, x = meshgrid_ij(
 | 
				
			||||||
 | 
					            torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
 | 
				
			||||||
 | 
					            torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        xy_grid = torch.stack([x, y], dim=-1).view(-1, 2)
 | 
				
			||||||
 | 
					        pixel_width = (max_x - min_x) / (image_width - 1)
 | 
				
			||||||
 | 
					        pixel_height = (max_y - min_y) / (image_height - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for cam_type in CAMERA_TYPES:
 | 
				
			||||||
 | 
					            # init a batch of random cameras
 | 
				
			||||||
 | 
					            cameras = init_random_cameras(cam_type, batch_size, random_z=True)
 | 
				
			||||||
 | 
					            # This method allow us to compute the radii whithout having
 | 
				
			||||||
 | 
					            # access to the full grid. Raysamplers during the training
 | 
				
			||||||
 | 
					            # will sample random rays from the grid.
 | 
				
			||||||
 | 
					            radii = compute_radii(
 | 
				
			||||||
 | 
					                cameras, xy_grid, pixel_hw_ndc=(pixel_height, pixel_width)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            plane_at_depth1 = unproject_xy_grid_from_ndc_to_world_coord(
 | 
				
			||||||
 | 
					                cameras, xy_grid
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            # This method absolutely needs the full grid to work.
 | 
				
			||||||
 | 
					            expected_radii = compute_pixel_radii_from_grid(
 | 
				
			||||||
 | 
					                plane_at_depth1.reshape(1, image_height, image_width, 3)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertClose(expected_radii.reshape(-1, 1), radii)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_forward(self):
 | 
				
			||||||
 | 
					        n_rays_per_image = 16
 | 
				
			||||||
 | 
					        image_height, image_width = 20, 20
 | 
				
			||||||
 | 
					        kwargs = {
 | 
				
			||||||
 | 
					            "image_width": image_width,
 | 
				
			||||||
 | 
					            "image_height": image_height,
 | 
				
			||||||
 | 
					            "n_pts_per_ray_training": 32,
 | 
				
			||||||
 | 
					            "n_pts_per_ray_evaluation": 32,
 | 
				
			||||||
 | 
					            "n_rays_per_image_sampled_from_mask": n_rays_per_image,
 | 
				
			||||||
 | 
					            "cast_ray_bundle_as_cone": False,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_size = 2
 | 
				
			||||||
 | 
					        samplers = [NearFarRaySampler(**kwargs), AdaptiveRaySampler(**kwargs)]
 | 
				
			||||||
 | 
					        evaluation_modes = [EvaluationMode.TRAINING, EvaluationMode.EVALUATION]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for cam_type, sampler, evaluation_mode in product(
 | 
				
			||||||
 | 
					            CAMERA_TYPES, samplers, evaluation_modes
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            cameras = init_random_cameras(cam_type, batch_size, random_z=True)
 | 
				
			||||||
 | 
					            ray_bundle = sampler(cameras, evaluation_mode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            shape_out = (
 | 
				
			||||||
 | 
					                (batch_size, image_width, image_height)
 | 
				
			||||||
 | 
					                if evaluation_mode == EvaluationMode.EVALUATION
 | 
				
			||||||
 | 
					                else (batch_size, n_rays_per_image, 1)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            n_pts_per_ray = (
 | 
				
			||||||
 | 
					                kwargs["n_pts_per_ray_evaluation"]
 | 
				
			||||||
 | 
					                if evaluation_mode == EvaluationMode.EVALUATION
 | 
				
			||||||
 | 
					                else kwargs["n_pts_per_ray_training"]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertIsNone(ray_bundle.bins)
 | 
				
			||||||
 | 
					            self.assertIsNone(ray_bundle.pixel_radii_2d)
 | 
				
			||||||
 | 
					            self.assertEqual(
 | 
				
			||||||
 | 
					                ray_bundle.lengths.shape,
 | 
				
			||||||
 | 
					                (*shape_out, n_pts_per_ray),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertEqual(ray_bundle.directions.shape, (*shape_out, 3))
 | 
				
			||||||
 | 
					            self.assertEqual(ray_bundle.origins.shape, (*shape_out, 3))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_forward_with_use_bins(self):
 | 
				
			||||||
 | 
					        n_rays_per_image = 16
 | 
				
			||||||
 | 
					        image_height, image_width = 20, 20
 | 
				
			||||||
 | 
					        kwargs = {
 | 
				
			||||||
 | 
					            "image_width": image_width,
 | 
				
			||||||
 | 
					            "image_height": image_height,
 | 
				
			||||||
 | 
					            "n_pts_per_ray_training": 32,
 | 
				
			||||||
 | 
					            "n_pts_per_ray_evaluation": 32,
 | 
				
			||||||
 | 
					            "n_rays_per_image_sampled_from_mask": n_rays_per_image,
 | 
				
			||||||
 | 
					            "cast_ray_bundle_as_cone": True,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_size = 1
 | 
				
			||||||
 | 
					        samplers = [NearFarRaySampler(**kwargs), AdaptiveRaySampler(**kwargs)]
 | 
				
			||||||
 | 
					        evaluation_modes = [EvaluationMode.TRAINING, EvaluationMode.EVALUATION]
 | 
				
			||||||
 | 
					        for cam_type, sampler, evaluation_mode in product(
 | 
				
			||||||
 | 
					            CAMERA_TYPES, samplers, evaluation_modes
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
 | 
					            cameras = init_random_cameras(cam_type, batch_size, random_z=True)
 | 
				
			||||||
 | 
					            ray_bundle = sampler(cameras, evaluation_mode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            lengths = 0.5 * (ray_bundle.bins[..., :-1] + ray_bundle.bins[..., 1:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.assertClose(ray_bundle.lengths, lengths)
 | 
				
			||||||
 | 
					            shape_out = (
 | 
				
			||||||
 | 
					                (batch_size, image_width, image_height)
 | 
				
			||||||
 | 
					                if evaluation_mode == EvaluationMode.EVALUATION
 | 
				
			||||||
 | 
					                else (batch_size, n_rays_per_image, 1)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertEqual(ray_bundle.pixel_radii_2d.shape, (*shape_out, 1))
 | 
				
			||||||
 | 
					            self.assertEqual(ray_bundle.directions.shape, (*shape_out, 3))
 | 
				
			||||||
 | 
					            self.assertEqual(ray_bundle.origins.shape, (*shape_out, 3))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Helper to test compute_radii
 | 
				
			||||||
 | 
					def compute_pixel_radii_from_grid(pixel_grid: torch.Tensor) -> torch.Tensor:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Compute the radii of a conical frustum given the pixel grid.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    To compute the radii we first compute the translation from a pixel
 | 
				
			||||||
 | 
					    to its neighbors along the x and y axis. Then, we compute the norm
 | 
				
			||||||
 | 
					    of each translation along the x and y axis.
 | 
				
			||||||
 | 
					    The radii are then obtained by the following formula:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    (dx_norm + dy_norm) * 0.5 * 2 / 12**0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    where 2/12**0.5 is a scaling factor to match
 | 
				
			||||||
 | 
					    the variance of the pixel’s footprint.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        pixel_grid: A tensor of shape `(..., H, W, dim)` representing the
 | 
				
			||||||
 | 
					            full grid of rays pixel_grid.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        The radiis for each pixels and shape `(..., H, W, 1)`.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # [B, H, W - 1, 3]
 | 
				
			||||||
 | 
					    x_translation = torch.diff(pixel_grid, dim=-2)
 | 
				
			||||||
 | 
					    # [B, H - 1, W, 3]
 | 
				
			||||||
 | 
					    y_translation = torch.diff(pixel_grid, dim=-3)
 | 
				
			||||||
 | 
					    # [B, H, W - 1, 1]
 | 
				
			||||||
 | 
					    dx_norm = torch.linalg.norm(x_translation, dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					    # [B, H - 1, W, 1]
 | 
				
			||||||
 | 
					    dy_norm = torch.linalg.norm(y_translation, dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Fill the missing value [B, H, W, 1]
 | 
				
			||||||
 | 
					    dx_norm = torch.concatenate([dx_norm, dx_norm[..., -1:, :]], -2)
 | 
				
			||||||
 | 
					    dy_norm = torch.concatenate([dy_norm, dy_norm[..., -1:, :, :]], -3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Cut the distance in half to obtain the base radius: (dx_norm + dy_norm) * 0.5
 | 
				
			||||||
 | 
					    # and multiply it by the scaling factor: * 2 / 12**0.5
 | 
				
			||||||
 | 
					    radii = (dx_norm + dy_norm) / 12**0.5
 | 
				
			||||||
 | 
					    return radii
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestRadiiComputationOnFullGrid(TestCaseMixin, unittest.TestCase):
 | 
				
			||||||
 | 
					    def test_compute_pixel_radii_from_grid(self):
 | 
				
			||||||
 | 
					        pixel_grid = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [[0.0, 0, 0], [1.0, 0.0, 0], [3.0, 0.0, 0.0]],
 | 
				
			||||||
 | 
					                [[0.0, 0.25, 0], [1.0, 0.25, 0], [3.0, 0.25, 0]],
 | 
				
			||||||
 | 
					                [[0.0, 1, 0], [1.0, 1.0, 0], [3.0000, 1.0, 0]],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        expected_y_norm = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                [0.25, 0.25, 0.25],
 | 
				
			||||||
 | 
					                [0.75, 0.75, 0.75],
 | 
				
			||||||
 | 
					                [0.75, 0.75, 0.75],  # duplicated from previous row
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        expected_x_norm = torch.tensor(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                # 3rd column is duplicated from 2nd
 | 
				
			||||||
 | 
					                [1.0, 2.0, 2.0],
 | 
				
			||||||
 | 
					                [1.0, 2.0, 2.0],
 | 
				
			||||||
 | 
					                [1.0, 2.0, 2.0],
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        expected_radii = (expected_x_norm + expected_y_norm) / 12**0.5
 | 
				
			||||||
 | 
					        radii = compute_pixel_radii_from_grid(pixel_grid)
 | 
				
			||||||
 | 
					        self.assertClose(radii, expected_radii[..., None])
 | 
				
			||||||
@ -32,7 +32,6 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
import pickle
 | 
					import pickle
 | 
				
			||||||
import typing
 | 
					 | 
				
			||||||
import unittest
 | 
					import unittest
 | 
				
			||||||
from itertools import product
 | 
					from itertools import product
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -60,6 +59,8 @@ from pytorch3d.transforms import Transform3d
 | 
				
			|||||||
from pytorch3d.transforms.rotation_conversions import random_rotations
 | 
					from pytorch3d.transforms.rotation_conversions import random_rotations
 | 
				
			||||||
from pytorch3d.transforms.so3 import so3_exp_map
 | 
					from pytorch3d.transforms.so3 import so3_exp_map
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .common_camera_utils import init_random_cameras
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .common_testing import TestCaseMixin
 | 
					from .common_testing import TestCaseMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -151,60 +152,6 @@ def ndc_to_screen_points_naive(points, imsize):
 | 
				
			|||||||
    return torch.stack((x, y, z), dim=2)
 | 
					    return torch.stack((x, y, z), dim=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def init_random_cameras(
    cam_type: typing.Type[CamerasBase],
    batch_size: int,
    random_z: bool = False,
    device: Device = "cpu",
):
    """
    Build a batch of cameras of type `cam_type` with random parameters.

    Args:
        cam_type: The camera class to instantiate.
        batch_size: Number of cameras in the batch.
        random_z: If False, the z component of every translation is set to 4;
            otherwise it keeps its randomly drawn value.
        device: Forwarded to the camera constructor as `device`.

    Returns:
        An instance of `cam_type` built from the random parameters.

    Raises:
        ValueError: if `cam_type` is not one of the handled camera classes.
    """
    # NOTE: the order of the random draws below is kept unchanged so that
    # seeded runs produce the same cameras.
    T = torch.randn(batch_size, 3) * 0.03
    if not random_z:
        T[:, 2] = 4
    R = so3_exp_map(torch.randn(batch_size, 3) * 3.0)
    cam_params = {"R": R, "T": T, "device": device}
    if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
        cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
        # zfar is always drawn beyond znear.
        cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
        if cam_type == OpenGLPerspectiveCameras:
            cam_params["fov"] = torch.rand(batch_size) * 60 + 30
            cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
        else:
            cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9
            cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
    elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras):
        cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
        # zfar is always drawn beyond znear.
        cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
        if cam_type == FoVPerspectiveCameras:
            cam_params["fov"] = torch.rand(batch_size) * 60 + 30
            cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
        else:
            cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9
            cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9
            cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9
    elif cam_type in (
        SfMOrthographicCameras,
        SfMPerspectiveCameras,
        OrthographicCameras,
        PerspectiveCameras,
    ):
        cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
        cam_params["principal_point"] = torch.randn((batch_size, 2))
    elif cam_type == FishEyeCameras:
        cam_params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1
        cam_params["principal_point"] = torch.randn((batch_size, 2))
        cam_params["radial_params"] = torch.randn((batch_size, 6))
        cam_params["tangential_params"] = torch.randn((batch_size, 2))
        cam_params["thin_prism_params"] = torch.randn((batch_size, 4))

    else:
        raise ValueError(str(cam_type))
    return cam_type(**cam_params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
 | 
					class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
 | 
				
			||||||
    def setUp(self) -> None:
 | 
					    def setUp(self) -> None:
 | 
				
			||||||
        super().setUp()
 | 
					        super().setUp()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user