diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index d217c867..c60d444d 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -216,6 +216,7 @@ model_factory_ImplicitronModelFactory_args: n_rays_total_training: null stratified_point_sampling_training: true stratified_point_sampling_evaluation: false + cast_ray_bundle_as_cone: false scene_extent: 8.0 scene_center: - 0.0 @@ -228,6 +229,7 @@ model_factory_ImplicitronModelFactory_args: n_rays_total_training: null stratified_point_sampling_training: true stratified_point_sampling_evaluation: false + cast_ray_bundle_as_cone: false min_depth: 0.1 max_depth: 8.0 renderer_LSTMRenderer_args: @@ -642,6 +644,7 @@ model_factory_ImplicitronModelFactory_args: n_rays_total_training: null stratified_point_sampling_training: true stratified_point_sampling_evaluation: false + cast_ray_bundle_as_cone: false scene_extent: 8.0 scene_center: - 0.0 @@ -654,6 +657,7 @@ model_factory_ImplicitronModelFactory_args: n_rays_total_training: null stratified_point_sampling_training: true stratified_point_sampling_evaluation: false + cast_ray_bundle_as_cone: false min_depth: 0.1 max_depth: 8.0 renderer_LSTMRenderer_args: diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index a8644dde..41d9ddbb 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch from pytorch3d.implicitron.tools.config import ReplaceableBase from pytorch3d.ops import packed_to_padded +from pytorch3d.renderer.implicit.utils import ray_bundle_variables_to_ray_points class EvaluationMode(Enum): @@ -47,6 +48,27 @@ class ImplicitronRayBundle: camera_counts: A tensor of shape (N, ) which how many times the coresponding camera in `camera_ids` was sampled. `sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`. + + Attributes: + origins: A tensor of shape `(..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(..., 3)` containing the direction + vectors of sampling rays in world coords. They don't have to be normalized; + they define unit vectors in the respective 1D coordinate systems; see + documentation for :func:`ray_bundle_to_ray_points` for the conversion formula. + lengths: A tensor of shape `(..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels + camera_ids: An optional tensor of shape (N, ) which indicates which camera + was used to sample the rays. `N` is the number of unique sampled cameras. + camera_counts: An optional tensor of shape (N, ) indicates how many times the + coresponding camera in `camera_ids` was sampled. + `sum(camera_counts)==total_number_of_rays`. + bins: An optional tensor of shape `(..., num_points_per_ray + 1)` + containing the bins at which the rays are sampled. In this case + lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`. + pixel_radii_2d: An optional tensor of shape `(..., 1)` + base radii of the conical frustums. """ origins: torch.Tensor @@ -55,6 +77,45 @@ class ImplicitronRayBundle: xys: torch.Tensor camera_ids: Optional[torch.LongTensor] = None camera_counts: Optional[torch.LongTensor] = None + bins: Optional[torch.Tensor] = None + pixel_radii_2d: Optional[torch.Tensor] = None + + @classmethod + def from_bins( + cls, + origins: torch.Tensor, + directions: torch.Tensor, + bins: torch.Tensor, + xys: torch.Tensor, + **kwargs, + ) -> "ImplicitronRayBundle": + """ + Creates a new instance from bins instead of lengths. + + Attributes: + origins: A tensor of shape `(..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(..., 3)` containing the direction + vectors of sampling rays in world coords. They don't have to be normalized; + they define unit vectors in the respective 1D coordinate systems; see + documentation for :func:`ray_bundle_to_ray_points` for the conversion formula. + bins: A tensor of shape `(..., num_points_per_ray + 1)` + containing the bins at which the rays are sampled. In this case + lengths is equal to the midpoints of bins `(..., num_points_per_ray)`. + xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels + kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle + Returns: + An instance of ImplicitronRayBundle. + """ + + if bins.shape[-1] <= 1: + raise ValueError( + "The last dim of bins must be at least superior or equal to 2." + ) + # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient + lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5) + + return cls(origins, directions, lengths, xys, bins=bins, **kwargs) def is_packed(self) -> bool: """ @@ -195,3 +256,154 @@ class BaseRenderer(ABC, ReplaceableBase): instance of RendererOutput """ pass + + +def compute_3d_diagonal_covariance_gaussian( + rays_directions: torch.Tensor, + rays_dir_variance: torch.Tensor, + radii_variance: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """ + Transform the variances (rays_dir_variance, radii_variance) of the gaussians from + the coordinate frame of the conical frustum to 3D world coordinates. + + It follows the equation 16 of `MIP-NeRF `_ + + Args: + rays_directions: A tensor of shape `(..., 3)` + rays_dir_variance: A tensor of shape `(..., num_intervals)` representing + the variance of the conical frustum with respect to the rays direction. + radii_variance: A tensor of shape `(..., num_intervals)` representing + the variance of the conical frustum with respect to its radius. + eps: a small number to prevent division by zero. + + Returns: + A tensor of shape `(..., num_intervals, 3)` containing the diagonal + of the covariance matrix. + """ + d_outer_diag = torch.pow(rays_directions, 2) + dir_mag_sq = torch.clamp(torch.sum(d_outer_diag, dim=-1, keepdim=True), min=eps) + + null_outer_diag = 1 - d_outer_diag / dir_mag_sq + ray_dir_cov_diag = rays_dir_variance[..., None] * d_outer_diag[..., None, :] + xy_cov_diag = radii_variance[..., None] * null_outer_diag[..., None, :] + return ray_dir_cov_diag + xy_cov_diag + + +def approximate_conical_frustum_as_gaussians( + bins: torch.Tensor, radii: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Approximates a conical frustum as two Gaussian distributions. + + The Gaussian distributions are characterized by + three values: + + - rays_dir_mean: mean along the rays direction + (defined as t in the parametric representation of a cone). + - rays_dir_variance: the variance of the conical frustum along the rays direction. + - radii_variance: variance of the conical frustum with respect to its radius. + + + The computation is stable and follows equation 7 + of `MIP-NeRF `_. + + For more information on how the mean and variances are computed + refers to the appendix of the paper. + + Args: + bins: A tensor of shape `(..., num_points_per_ray + 1)` + containing the bins at which the rays are sampled. + `bin[..., t]` and `bin[..., t+1]` represent respectively + the left and right coordinates of the interval. + t0: A tensor of shape `(..., num_points_per_ray)` + containing the left coordinates of the intervals + on which the rays are sampled. + t1: A tensor of shape `(..., num_points_per_ray)` + containing the rights coordinates of the intervals + on which the rays are sampled. + radii: A tensor of shape `(..., 1)` + base radii of the conical frustums. + + Returns: + rays_dir_mean: A tensor of shape `(..., num_intervals)` representing + the mean along the rays direction + (t in the parametric represention of the cone) + rays_dir_variance: A tensor of shape `(..., num_intervals)` representing + the variance of the conical frustum along the rays + (t in the parametric represention of the cone). + radii_variance: A tensor of shape `(..., num_intervals)` representing + the variance of the conical frustum with respect to its radius. + """ + t_mu = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5) + t_delta = torch.diff(bins, dim=-1) / 2 + + t_mu_pow2 = torch.pow(t_mu, 2) + t_delta_pow2 = torch.pow(t_delta, 2) + t_delta_pow4 = torch.pow(t_delta, 4) + + den = 3 * t_mu_pow2 + t_delta_pow2 + + # mean along the rays direction + rays_dir_mean = t_mu + 2 * t_mu * t_delta_pow2 / den + + # Variance of the conical frustum with along the rays directions + rays_dir_variance = t_delta_pow2 / 3 - (4 / 15) * ( + t_delta_pow4 * (12 * t_mu_pow2 - t_delta_pow2) / torch.pow(den, 2) + ) + + # Variance of the conical frustum with respect to its radius + radii_variance = torch.pow(radii, 2) * ( + t_mu_pow2 / 4 + (5 / 12) * t_delta_pow2 - 4 / 15 * (t_delta_pow4) / den + ) + return rays_dir_mean, rays_dir_variance, radii_variance + + +def conical_frustum_to_gaussian( + ray_bundle: ImplicitronRayBundle, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Approximate a conical frustum following a ray bundle as a Gaussian. + + Args: + ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields: + origins: A tensor of shape `(..., 3)` + directions: A tensor of shape `(..., 3)` + lengths: A tensor of shape `(..., num_points_per_ray)` + bins: A tensor of shape `(..., num_points_per_ray + 1)` + containing the bins at which the rays are sampled. . + pixel_radii_2d: A tensor of shape `(..., 1)` + base radii of the conical frustums. + + Returns: + means: A tensor of shape `(..., num_points_per_ray - 1, 3)` + representing the means of the Gaussians + approximating the conical frustums. + diag_covariances: A tensor of shape `(...,num_points_per_ray -1, 3)` + representing the diagonal covariance matrices of our Gaussians. + """ + + if ray_bundle.pixel_radii_2d is None or ray_bundle.bins is None: + raise ValueError( + "RayBundle pixel_radii_2d or bins have not been provided." + " Look at pytorch3d.renderer.implicit.renderer.ray_sampler::" + "AbstractMaskRaySampler to see how to compute them. Have you forgot to set" + "`cast_ray_bundle_as_cone` to True?" + ) + + ( + rays_dir_mean, + rays_dir_variance, + radii_variance, + ) = approximate_conical_frustum_as_gaussians( + ray_bundle.bins, + ray_bundle.pixel_radii_2d, + ) + means = ray_bundle_variables_to_ray_points( + ray_bundle.origins, ray_bundle.directions, rays_dir_mean + ) + diag_covariances = compute_3d_diagonal_covariance_gaussian( + ray_bundle.directions, rays_dir_variance, radii_variance + ) + return means, diag_covariances diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index 915140c8..d3f1c6b3 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -11,6 +11,7 @@ from pytorch3d.implicitron.tools import camera_utils from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.renderer import NDCMultinomialRaysampler from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode @@ -83,7 +84,20 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): stratified_point_sampling_training: if set, performs stratified random sampling along the ray; otherwise takes ray points at deterministic offsets. stratified_point_sampling_evaluation: Same as above but for evaluation. + cast_ray_bundle_as_cone: If True, the sampling will generate the bins and radii + attribute of ImplicitronRayBundle. The `bins` contain the z-coordinate + (=depth) of each ray in world units and are of shape + `(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation + 1)` + while `lengths` is equal to the midpoint of the bins: + (0.5 * (bins[..., 1:] + bins[..., :-1]). + If False, `bins` is None, `radii` is None and `lengths` contains + the z-coordinate (=depth) of each ray in world units and are of shape + `(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation)` + Raises: + TypeError: if cast_ray_bundle_as_cone is set to True and n_rays_total_training + is not None will result in an error. HeterogeneousRayBundle is + not supported for conical frustum computation yet. """ image_width: int = 400 @@ -97,6 +111,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): # stratified sampling vs taking points at deterministic offsets stratified_point_sampling_training: bool = True stratified_point_sampling_evaluation: bool = False + cast_ray_bundle_as_cone: bool = False def __post_init__(self): if (self.n_rays_per_image_sampled_from_mask is not None) and ( @@ -114,10 +129,20 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): ), } + n_pts_per_ray_training = ( + self.n_pts_per_ray_training + 1 + if self.cast_ray_bundle_as_cone + else self.n_pts_per_ray_training + ) + n_pts_per_ray_evaluation = ( + self.n_pts_per_ray_evaluation + 1 + if self.cast_ray_bundle_as_cone + else self.n_pts_per_ray_evaluation + ) self._training_raysampler = NDCMultinomialRaysampler( image_width=self.image_width, image_height=self.image_height, - n_pts_per_ray=self.n_pts_per_ray_training, + n_pts_per_ray=n_pts_per_ray_training, min_depth=0.0, max_depth=0.0, n_rays_per_image=self.n_rays_per_image_sampled_from_mask @@ -132,7 +157,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): self._evaluation_raysampler = NDCMultinomialRaysampler( image_width=self.image_width, image_height=self.image_height, - n_pts_per_ray=self.n_pts_per_ray_evaluation, + n_pts_per_ray=n_pts_per_ray_evaluation, min_depth=0.0, max_depth=0.0, n_rays_per_image=self.n_rays_per_image_sampled_from_mask @@ -143,6 +168,11 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): stratified_sampling=self.stratified_point_sampling_evaluation, ) + max_y, min_y = self._training_raysampler.max_y, self._training_raysampler.min_y + max_x, min_x = self._training_raysampler.max_x, self._training_raysampler.min_x + self.pixel_height: float = (max_y - min_y) / (self.image_height - 1) + self.pixel_width: float = (max_x - min_x) / (self.image_width - 1) + def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: raise NotImplementedError() @@ -193,19 +223,34 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): min_depth=min_depth, max_depth=max_depth, ) - - if isinstance(ray_bundle, tuple): - return ImplicitronRayBundle( - # pyre-ignore[16] - **{k: v for k, v in ray_bundle._asdict().items()} + if self.cast_ray_bundle_as_cone and isinstance( + ray_bundle, HeterogeneousRayBundle + ): + # If this error rises it means that raysampler has among + # its arguments `n_ray_totals`. If it is the case + # then you should update the radii computation and lengths + # computation to handle padding and unpadding. + raise TypeError( + "Heterogeneous ray bundle is not supported for conical frustum computation yet" ) + elif self.cast_ray_bundle_as_cone: + pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width) + pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw) + return ImplicitronRayBundle.from_bins( + directions=ray_bundle.directions, + origins=ray_bundle.origins, + bins=ray_bundle.lengths, + xys=ray_bundle.xys, + pixel_radii_2d=pixel_radii_2d, + ) + return ImplicitronRayBundle( directions=ray_bundle.directions, origins=ray_bundle.origins, lengths=ray_bundle.lengths, xys=ray_bundle.xys, - camera_ids=ray_bundle.camera_ids, - camera_counts=ray_bundle.camera_counts, + camera_counts=getattr(ray_bundle, "camera_counts", None), + camera_ids=getattr(ray_bundle, "camera_ids", None), ) @@ -274,3 +319,62 @@ class NearFarRaySampler(AbstractMaskRaySampler): Returns the stored near/far planes. """ return self.min_depth, self.max_depth + + +def compute_radii( + cameras: CamerasBase, + xy_grid: torch.Tensor, + pixel_hw_ndc: Tuple[float, float], +) -> torch.Tensor: + """ + Compute radii of conical frustums in world coordinates. + + Args: + cameras: cameras object representing a batch of cameras. + xy_grid: torch.tensor grid of image xy coords. + pixel_hw_ndc: pixel height and width in NDC + + Returns: + radii: A tensor of shape `(..., 1)` radii of a cone. + """ + batch_size = xy_grid.shape[0] + spatial_size = xy_grid.shape[1:-1] + n_rays_per_image = spatial_size.numel() + + xy = xy_grid.view(batch_size, n_rays_per_image, 2) + + # [batch_size, 3 * n_rays_per_image, 2] + xy = torch.cat( + [ + xy, + # Will allow to find the norm on the x axis + xy + torch.tensor([pixel_hw_ndc[1], 0], device=xy.device), + # Will allow to find the norm on the y axis + xy + torch.tensor([0, pixel_hw_ndc[0]], device=xy.device), + ], + dim=1, + ) + # [batch_size, 3 * n_rays_per_image, 3] + xyz = torch.cat( + ( + xy, + xy.new_ones(batch_size, 3 * n_rays_per_image, 1), + ), + dim=-1, + ) + + # unproject the points + unprojected_xyz = cameras.unproject_points(xyz, from_ndc=True) + + plane_world, plane_world_dx, plane_world_dy = torch.split( + unprojected_xyz, n_rays_per_image, dim=1 + ) + + # Distance from each unit-norm direction vector to its neighbors. + dx_norm = torch.linalg.norm(plane_world_dx - plane_world, dim=-1, keepdims=True) + dy_norm = torch.linalg.norm(plane_world_dy - plane_world, dim=-1, keepdims=True) + # Cut the distance in half to obtain the base radius: (dx_norm + dy_norm) * 0.5 + # Scale it by 2/12**0.5 to match the variance of the pixel’s footprint + radii = (dx_norm + dy_norm) / 12**0.5 + + return radii.view(batch_size, *spatial_size, 1) diff --git a/pytorch3d/implicitron/models/utils.py b/pytorch3d/implicitron/models/utils.py index 94480cd6..b2f7dc66 100644 --- a/pytorch3d/implicitron/models/utils.py +++ b/pytorch3d/implicitron/models/utils.py @@ -177,6 +177,20 @@ def chunk_generator( for start_idx in iter: end_idx = min(start_idx + chunk_size_in_rays, n_rays) + bins = ( + None + if ray_bundle.bins is None + else ray_bundle.bins.reshape(batch_size, n_rays, n_pts_per_ray + 1)[ + :, start_idx:end_idx + ] + ) + pixel_radii_2d = ( + None + if ray_bundle.pixel_radii_2d is None + else ray_bundle.pixel_radii_2d.reshape(batch_size, -1, 1)[ + :, start_idx:end_idx + ] + ) ray_bundle_chunk = ImplicitronRayBundle( origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx], directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ @@ -186,6 +200,8 @@ def chunk_generator( :, start_idx:end_idx ], xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx], + bins=bins, + pixel_radii_2d=pixel_radii_2d, camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx), camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx), ) diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index a1453819..c81178af 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -58,6 +58,12 @@ class MultinomialRaysampler(torch.nn.Module): coordinate convention. For a raysampler which follows the PyTorch3D coordinate conventions please refer to `NDCMultinomialRaysampler`. As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`. + + Attributes: + min_x: The leftmost x-coordinate of each ray's source pixel's center. + max_x: The rightmost x-coordinate of each ray's source pixel's center. + min_y: The topmost y-coordinate of each ray's source pixel's center. + max_y: The bottommost y-coordinate of each ray's source pixel's center. """ def __init__( @@ -107,7 +113,8 @@ class MultinomialRaysampler(torch.nn.Module): self._n_rays_total = n_rays_total self._unit_directions = unit_directions self._stratified_sampling = stratified_sampling - + self.min_x, self.max_x = min_x, max_x + self.min_y, self.max_y = min_y, max_y # get the initial grid of image xy coords y, x = meshgrid_ij( torch.linspace(min_y, max_y, image_height, dtype=torch.float32), diff --git a/tests/common_camera_utils.py b/tests/common_camera_utils.py new file mode 100644 index 00000000..aa7aeb2d --- /dev/null +++ b/tests/common_camera_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import typing + +import torch +from pytorch3d.common.datatypes import Device +from pytorch3d.renderer.cameras import ( + CamerasBase, + FoVOrthographicCameras, + FoVPerspectiveCameras, + OpenGLOrthographicCameras, + OpenGLPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, + SfMOrthographicCameras, + SfMPerspectiveCameras, +) +from pytorch3d.renderer.fisheyecameras import FishEyeCameras +from pytorch3d.transforms.so3 import so3_exp_map + + +def init_random_cameras( + cam_type: typing.Type[CamerasBase], + batch_size: int, + random_z: bool = False, + device: Device = "cpu", +): + cam_params = {} + T = torch.randn(batch_size, 3) * 0.03 + if not random_z: + T[:, 2] = 4 + R = so3_exp_map(torch.randn(batch_size, 3) * 3.0) + cam_params = {"R": R, "T": T, "device": device} + if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras): + cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 + cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] + if cam_type == OpenGLPerspectiveCameras: + cam_params["fov"] = torch.rand(batch_size) * 60 + 30 + cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 + else: + cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9 + cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9 + elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras): + cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 + cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] + if cam_type == FoVPerspectiveCameras: + cam_params["fov"] = torch.rand(batch_size) * 60 + 30 + cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 + else: + cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9 + cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9 + elif cam_type in ( + SfMOrthographicCameras, + SfMPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, + ): + cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1 + cam_params["principal_point"] = torch.randn((batch_size, 2)) + elif cam_type == FishEyeCameras: + cam_params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1 + cam_params["principal_point"] = torch.randn((batch_size, 2)) + cam_params["radial_params"] = torch.randn((batch_size, 6)) + cam_params["tangential_params"] = torch.randn((batch_size, 2)) + cam_params["thin_prism_params"] = torch.randn((batch_size, 4)) + + else: + raise ValueError(str(cam_type)) + return cam_type(**cam_params) diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index e02f7bf6..95899d89 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -62,6 +62,7 @@ raysampler_AdaptiveRaySampler_args: n_rays_total_training: null stratified_point_sampling_training: true stratified_point_sampling_evaluation: false + cast_ray_bundle_as_cone: false scene_extent: 8.0 scene_center: - 0.0 diff --git a/tests/implicitron/test_models_renderer_base.py b/tests/implicitron/test_models_renderer_base.py new file mode 100644 index 00000000..4b7827b1 --- /dev/null +++ b/tests/implicitron/test_models_renderer_base.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import numpy as np + +import torch + +from pytorch3d.implicitron.models.renderer.base import ( + approximate_conical_frustum_as_gaussians, + compute_3d_diagonal_covariance_gaussian, + conical_frustum_to_gaussian, + ImplicitronRayBundle, +) +from pytorch3d.implicitron.models.renderer.ray_sampler import AbstractMaskRaySampler + +from tests.common_testing import TestCaseMixin + + +class TestRendererBase(TestCaseMixin, unittest.TestCase): + def test_implicitron_from_bins(self) -> None: + bins = torch.randn(2, 3, 4, 5) + ray_bundle = ImplicitronRayBundle.from_bins( + origins=None, + directions=None, + xys=None, + bins=bins, + ) + self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1])) + self.assertClose(ray_bundle.bins, bins) + + def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None: + with self.assertRaises(ValueError): + ImplicitronRayBundle.from_bins( + origins=torch.rand(2, 3, 4, 3), + directions=torch.rand(2, 3, 4, 3), + xys=torch.rand(2, 3, 4, 2), + bins=torch.rand(2, 3, 4, 1), + ) + + def test_conical_frustum_to_gaussian(self) -> None: + origins = torch.zeros(3, 3, 3) + directions = torch.tensor( + [ + [[0, 0, 0], [1, 0, 0], [3, 0, 0]], + [[0, 0.25, 0], [1, 0.25, 0], [3, 0.25, 0]], + [[0, 1, 0], [1, 1, 0], [3, 1, 0]], + ] + ) + bins = torch.tensor( + [ + [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]], + [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]], + [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]], + ] + ) + # see test_compute_pixel_radii_from_ray_direction + radii = torch.tensor( + [ + [1.25, 2.25, 2.25], + [1.75, 2.75, 2.75], + [1.75, 2.75, 2.75], + ] + ) + radii = radii[..., None] / 12**0.5 + + # The expected mean and diagonal covariance have been computed + # by hand from the official code of MipNerf. + # https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip.py#L125 + # mean, cov_diag = cast_rays(length, origins, directions, radii, 'cone', diag=True) + + expected_mean = torch.tensor( + [ + [ + [[0.0, 0.0, 0.0]], + [[0.5506329, 0.0, 0.0]], + [[1.6518986, 0.0, 0.0]], + ], + [ + [[0.0, 0.28846154, 0.0]], + [[0.5506329, 0.13765822, 0.0]], + [[1.6518986, 0.13765822, 0.0]], + ], + [ + [[0.0, 1.1538461, 0.0]], + [[0.5506329, 0.5506329, 0.0]], + [[1.6518986, 0.5506329, 0.0]], + ], + ] + ) + expected_diag_cov = torch.tensor( + [ + [ + [[0.04544772, 0.04544772, 0.04544772]], + [[0.01130973, 0.03317059, 0.03317059]], + [[0.10178753, 0.03317059, 0.03317059]], + ], + [ + [[0.08907752, 0.00404956, 0.08907752]], + [[0.0142245, 0.04734321, 0.04955113]], + [[0.10212927, 0.04991625, 0.04955113]], + ], + [ + [[0.08907752, 0.0647929, 0.08907752]], + [[0.03608529, 0.03608529, 0.04955113]], + [[0.10674264, 0.05590574, 0.04955113]], + ], + ] + ) + + ray = ImplicitronRayBundle( + origins=origins, + directions=directions, + bins=bins, + lengths=None, + pixel_radii_2d=radii, + xys=None, + ) + mean, diag_cov = conical_frustum_to_gaussian(ray) + + self.assertClose(mean, expected_mean) + self.assertClose(diag_cov, expected_diag_cov) + + def test_scale_conical_frustum_to_gaussian(self) -> None: + origins = torch.zeros(2, 2, 3) + directions = torch.Tensor( + [ + [[0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1]], + ] + ) + bins = torch.Tensor( + [ + [[0.5, 1.5], [0.3, 0.7]], + [[0.5, 1.5], [0.3, 0.7]], + ] + ) + radii = torch.ones(2, 2, 1) + + ray = ImplicitronRayBundle( + origins=origins, + directions=directions, + bins=bins, + pixel_radii_2d=radii, + lengths=None, + xys=None, + ) + + mean, diag_cov = conical_frustum_to_gaussian(ray) + + scaling_factor = 2.5 + ray = ImplicitronRayBundle( + origins=origins, + directions=directions, + bins=bins * scaling_factor, + pixel_radii_2d=radii, + lengths=None, + xys=None, + ) + mean_scaled, diag_cov_scaled = conical_frustum_to_gaussian(ray) + np.testing.assert_allclose(mean * scaling_factor, mean_scaled) + np.testing.assert_allclose( + diag_cov * scaling_factor**2, diag_cov_scaled, atol=1e-6 + ) + + def test_approximate_conical_frustum_as_gaussian(self) -> None: + """Ensure that the computation modularity in our function is well done.""" + bins = torch.Tensor([[0.5, 1.5], [0.3, 0.7]]) + radii = torch.Tensor([[1.0], [1.0]]) + t_mean, t_var, r_var = approximate_conical_frustum_as_gaussians(bins, radii) + + self.assertEqual(t_mean.shape, (2, 1)) + self.assertEqual(t_var.shape, (2, 1)) + self.assertEqual(r_var.shape, (2, 1)) + + mu = np.array([[1.0], [0.5]]) + delta = np.array([[0.5], [0.2]]) + + np.testing.assert_allclose( + mu + (2 * mu * delta**2) / (3 * mu**2 + delta**2), t_mean.numpy() + ) + np.testing.assert_allclose( + (delta**2) / 3 + - (4 / 15) + * ( + (delta**4 * (12 * mu**2 - delta**2)) + / (3 * mu**2 + delta**2) ** 2 + ), + t_var.numpy(), + ) + np.testing.assert_allclose( + radii**2 + * ( + (mu**2) / 4 + + (5 / 12) * delta**2 + - 4 / 15 * (delta**4) / (3 * mu**2 + delta**2) + ), + r_var.numpy(), + ) + + def test_compute_3d_diagonal_covariance_gaussian(self) -> None: + ray_directions = torch.Tensor([[0, 0, 1]]) + t_var = torch.Tensor([0.5, 0.5, 1]) + r_var = torch.Tensor([0.6, 0.3, 0.4]) + expected_diag_cov = np.array( + [ + [ + # t_cov_diag + xy_cov_diag + [0.0 + 0.6, 0.0 + 0.6, 0.5 + 0.0], + [0.0 + 0.3, 0.0 + 0.3, 0.5 + 0.0], + [0.0 + 0.4, 0.0 + 0.4, 1.0 + 0.0], + ] + ] + ) + diag_cov = compute_3d_diagonal_covariance_gaussian(ray_directions, t_var, r_var) + np.testing.assert_allclose(diag_cov.numpy(), expected_diag_cov) + + def test_conical_frustum_to_gaussian_raise_valueerror(self) -> None: + lengths = torch.linspace(0, 1, steps=6) + directions = torch.tensor([0, 0, 1]) + origins = torch.tensor([1, 1, 1]) + ray = ImplicitronRayBundle( + origins=origins, directions=directions, lengths=lengths, xys=None + ) + with self.assertRaises(ValueError) as context: + _ = conical_frustum_to_gaussian(ray) + + expected_error_message = ( + "RayBundle pixel_radii_2d or bins have not been provided." + " Look at pytorch3d.renderer.implicit.renderer.ray_sampler::" + "AbstractMaskRaySampler to see how to compute them. Have you forgot to set" + "`cast_ray_bundle_as_cone` to True?" + ) + + self.assertEqual(expected_error_message, str(context.exception)) + + # Ensure message is coherent with AbstractMaskRaySampler + class FakeRaySampler(AbstractMaskRaySampler): + def _get_min_max_depth_bounds(self, *args): + return None + + message_assertion = ( + "If cast_ray_bundle_as_cone has been removed please update the doc" + "conical_frustum_to_gaussian" + ) + self.assertIsNotNone( + getattr(FakeRaySampler(), "cast_ray_bundle_as_cone", None), + message_assertion, + ) diff --git a/tests/implicitron/test_models_renderer_ray_sampler.py b/tests/implicitron/test_models_renderer_ray_sampler.py new file mode 100644 index 00000000..17c1d132 --- /dev/null +++ b/tests/implicitron/test_models_renderer_ray_sampler.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from itertools import product +from typing import Tuple + +from unittest.mock import patch + +import torch +from pytorch3d.common.compat import meshgrid_ij +from pytorch3d.implicitron.models.renderer.base import EvaluationMode +from pytorch3d.implicitron.models.renderer.ray_sampler import ( + AdaptiveRaySampler, + compute_radii, + NearFarRaySampler, +) + +from pytorch3d.renderer.cameras import ( + CamerasBase, + FoVOrthographicCameras, + FoVPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, +) +from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle +from tests.common_camera_utils import init_random_cameras + +from tests.common_testing import TestCaseMixin + +CAMERA_TYPES = ( + FoVPerspectiveCameras, + FoVOrthographicCameras, + OrthographicCameras, + PerspectiveCameras, +) + + +def unproject_xy_grid_from_ndc_to_world_coord( + cameras: CamerasBase, xy_grid: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + + Unproject a xy_grid from NDC coordinates to world coordinates. + + Args: + cameras: CamerasBase. + xy_grid: A tensor of shape `(..., H*W, 2)` representing the + x, y coords. + + Returns: + A tensor of shape `(..., H*W, 3)` representing the + """ + + batch_size = xy_grid.shape[0] + n_rays_per_image = xy_grid.shape[1:-1].numel() + xy = xy_grid.view(batch_size, -1, 2) + xyz = torch.cat([xy, xy_grid.new_ones(batch_size, n_rays_per_image, 1)], dim=-1) + plane_at_depth1 = cameras.unproject_points(xyz, from_ndc=True) + return plane_at_depth1.view(*xy_grid.shape[:-1], 3) + + +class TestRaysampler(TestCaseMixin, unittest.TestCase): + def test_ndc_raysampler_n_ray_total_is_none(self): + sampler = NearFarRaySampler() + message = ( + "If you introduce the support of `n_rays_total` for {0}, please handle the " + "packing and unpacking logic for the radii and lengths computation." + ) + self.assertIsNone( + sampler._training_raysampler._n_rays_total, message.format(type(sampler)) + ) + self.assertIsNone( + sampler._evaluation_raysampler._n_rays_total, message.format(type(sampler)) + ) + + sampler = AdaptiveRaySampler() + self.assertIsNone( + sampler._training_raysampler._n_rays_total, message.format(type(sampler)) + ) + self.assertIsNone( + sampler._evaluation_raysampler._n_rays_total, message.format(type(sampler)) + ) + + def test_catch_heterogeneous_exception(self): + cameras = init_random_cameras(FoVPerspectiveCameras, 1, random_z=True) + + class FakeSampler: + def __init__(self): + self.min_x, self.max_x = 1, 2 + self.min_y, self.max_y = 1, 2 + + def __call__(self, **kwargs): + return HeterogeneousRayBundle( + torch.rand(3), torch.rand(3), torch.rand(3), torch.rand(1) + ) + + with patch( + "pytorch3d.implicitron.models.renderer.ray_sampler.NDCMultinomialRaysampler", + return_value=FakeSampler(), + ): + for sampler in [ + AdaptiveRaySampler(cast_ray_bundle_as_cone=True), + NearFarRaySampler(cast_ray_bundle_as_cone=True), + ]: + with self.assertRaises(TypeError): + _ = sampler(cameras, EvaluationMode.TRAINING) + for sampler in [ + AdaptiveRaySampler(cast_ray_bundle_as_cone=False), + NearFarRaySampler(cast_ray_bundle_as_cone=False), + ]: + _ = sampler(cameras, EvaluationMode.TRAINING) + + def test_compute_radii(self): + batch_size = 1 + image_height, image_width = 20, 10 + min_y, max_y, min_x, max_x = -1.0, 1.0, -1.0, 1.0 + y, x = meshgrid_ij( + torch.linspace(min_y, max_y, image_height, dtype=torch.float32), + torch.linspace(min_x, max_x, image_width, dtype=torch.float32), + ) + xy_grid = torch.stack([x, y], dim=-1).view(-1, 2) + pixel_width = (max_x - min_x) / (image_width - 1) + pixel_height = (max_y - min_y) / (image_height - 1) + + for cam_type in CAMERA_TYPES: + # init a batch of random cameras + cameras = init_random_cameras(cam_type, batch_size, random_z=True) + # This method allow us to compute the radii whithout having + # access to the full grid. Raysamplers during the training + # will sample random rays from the grid. + radii = compute_radii( + cameras, xy_grid, pixel_hw_ndc=(pixel_height, pixel_width) + ) + plane_at_depth1 = unproject_xy_grid_from_ndc_to_world_coord( + cameras, xy_grid + ) + # This method absolutely needs the full grid to work. + expected_radii = compute_pixel_radii_from_grid( + plane_at_depth1.reshape(1, image_height, image_width, 3) + ) + self.assertClose(expected_radii.reshape(-1, 1), radii) + + def test_forward(self): + n_rays_per_image = 16 + image_height, image_width = 20, 20 + kwargs = { + "image_width": image_width, + "image_height": image_height, + "n_pts_per_ray_training": 32, + "n_pts_per_ray_evaluation": 32, + "n_rays_per_image_sampled_from_mask": n_rays_per_image, + "cast_ray_bundle_as_cone": False, + } + + batch_size = 2 + samplers = [NearFarRaySampler(**kwargs), AdaptiveRaySampler(**kwargs)] + evaluation_modes = [EvaluationMode.TRAINING, EvaluationMode.EVALUATION] + + for cam_type, sampler, evaluation_mode in product( + CAMERA_TYPES, samplers, evaluation_modes + ): + cameras = init_random_cameras(cam_type, batch_size, random_z=True) + ray_bundle = sampler(cameras, evaluation_mode) + + shape_out = ( + (batch_size, image_width, image_height) + if evaluation_mode == EvaluationMode.EVALUATION + else (batch_size, n_rays_per_image, 1) + ) + n_pts_per_ray = ( + kwargs["n_pts_per_ray_evaluation"] + if evaluation_mode == EvaluationMode.EVALUATION + else kwargs["n_pts_per_ray_training"] + ) + self.assertIsNone(ray_bundle.bins) + self.assertIsNone(ray_bundle.pixel_radii_2d) + self.assertEqual( + ray_bundle.lengths.shape, + (*shape_out, n_pts_per_ray), + ) + self.assertEqual(ray_bundle.directions.shape, (*shape_out, 3)) + self.assertEqual(ray_bundle.origins.shape, (*shape_out, 3)) + + def test_forward_with_use_bins(self): + n_rays_per_image = 16 + image_height, image_width = 20, 20 + kwargs = { + "image_width": image_width, + "image_height": image_height, + "n_pts_per_ray_training": 32, + "n_pts_per_ray_evaluation": 32, + "n_rays_per_image_sampled_from_mask": n_rays_per_image, + "cast_ray_bundle_as_cone": True, + } + + batch_size = 1 + samplers = [NearFarRaySampler(**kwargs), AdaptiveRaySampler(**kwargs)] + evaluation_modes = [EvaluationMode.TRAINING, EvaluationMode.EVALUATION] + for cam_type, sampler, evaluation_mode in product( + CAMERA_TYPES, samplers, evaluation_modes + ): + cameras = init_random_cameras(cam_type, batch_size, random_z=True) + ray_bundle = sampler(cameras, evaluation_mode) + + lengths = 0.5 * (ray_bundle.bins[..., :-1] + ray_bundle.bins[..., 1:]) + + self.assertClose(ray_bundle.lengths, lengths) + shape_out = ( + (batch_size, image_width, image_height) + if evaluation_mode == EvaluationMode.EVALUATION + else (batch_size, n_rays_per_image, 1) + ) + self.assertEqual(ray_bundle.pixel_radii_2d.shape, (*shape_out, 1)) + self.assertEqual(ray_bundle.directions.shape, (*shape_out, 3)) + self.assertEqual(ray_bundle.origins.shape, (*shape_out, 3)) + + +# Helper to test compute_radii +def compute_pixel_radii_from_grid(pixel_grid: torch.Tensor) -> torch.Tensor: + """ + Compute the radii of a conical frustum given the pixel grid. + + To compute the radii we first compute the translation from a pixel + to its neighbors along the x and y axis. Then, we compute the norm + of each translation along the x and y axis. + The radii are then obtained by the following formula: + + (dx_norm + dy_norm) * 0.5 * 2 / 12**0.5 + + where 2/12**0.5 is a scaling factor to match + the variance of the pixel’s footprint. + + Args: + pixel_grid: A tensor of shape `(..., H, W, dim)` representing the + full grid of rays pixel_grid. + + Returns: + The radiis for each pixels and shape `(..., H, W, 1)`. + """ + # [B, H, W - 1, 3] + x_translation = torch.diff(pixel_grid, dim=-2) + # [B, H - 1, W, 3] + y_translation = torch.diff(pixel_grid, dim=-3) + # [B, H, W - 1, 1] + dx_norm = torch.linalg.norm(x_translation, dim=-1, keepdim=True) + # [B, H - 1, W, 1] + dy_norm = torch.linalg.norm(y_translation, dim=-1, keepdim=True) + + # Fill the missing value [B, H, W, 1] + dx_norm = torch.concatenate([dx_norm, dx_norm[..., -1:, :]], -2) + dy_norm = torch.concatenate([dy_norm, dy_norm[..., -1:, :, :]], -3) + + # Cut the distance in half to obtain the base radius: (dx_norm + dy_norm) * 0.5 + # and multiply it by the scaling factor: * 2 / 12**0.5 + radii = (dx_norm + dy_norm) / 12**0.5 + return radii + + +class TestRadiiComputationOnFullGrid(TestCaseMixin, unittest.TestCase): + def test_compute_pixel_radii_from_grid(self): + pixel_grid = torch.tensor( + [ + [[0.0, 0, 0], [1.0, 0.0, 0], [3.0, 0.0, 0.0]], + [[0.0, 0.25, 0], [1.0, 0.25, 0], [3.0, 0.25, 0]], + [[0.0, 1, 0], [1.0, 1.0, 0], [3.0000, 1.0, 0]], + ] + ) + + expected_y_norm = torch.tensor( + [ + [0.25, 0.25, 0.25], + [0.75, 0.75, 0.75], + [0.75, 0.75, 0.75], # duplicated from previous row + ] + ) + expected_x_norm = torch.tensor( + [ + # 3rd column is duplicated from 2nd + [1.0, 2.0, 2.0], + [1.0, 2.0, 2.0], + [1.0, 2.0, 2.0], + ] + ) + expected_radii = (expected_x_norm + expected_y_norm) / 12**0.5 + radii = compute_pixel_radii_from_grid(pixel_grid) + self.assertClose(radii, expected_radii[..., None]) diff --git a/tests/test_cameras.py b/tests/test_cameras.py index 043d6d51..7ca86d7d 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -32,7 +32,6 @@ import math import pickle -import typing import unittest from itertools import product @@ -60,6 +59,8 @@ from pytorch3d.transforms import Transform3d from pytorch3d.transforms.rotation_conversions import random_rotations from pytorch3d.transforms.so3 import so3_exp_map +from .common_camera_utils import init_random_cameras + from .common_testing import TestCaseMixin @@ -151,60 +152,6 @@ def ndc_to_screen_points_naive(points, imsize): return torch.stack((x, y, z), dim=2) -def init_random_cameras( - cam_type: typing.Type[CamerasBase], - batch_size: int, - random_z: bool = False, - device: Device = "cpu", -): - cam_params = {} - T = torch.randn(batch_size, 3) * 0.03 - if not random_z: - T[:, 2] = 4 - R = so3_exp_map(torch.randn(batch_size, 3) * 3.0) - cam_params = {"R": R, "T": T, "device": device} - if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras): - cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 - cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] - if cam_type == OpenGLPerspectiveCameras: - cam_params["fov"] = torch.rand(batch_size) * 60 + 30 - cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 - else: - cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9 - cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9 - elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras): - cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 - cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] - if cam_type == FoVPerspectiveCameras: - cam_params["fov"] = torch.rand(batch_size) * 60 + 30 - cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 - else: - cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9 - cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9 - elif cam_type in ( - SfMOrthographicCameras, - SfMPerspectiveCameras, - OrthographicCameras, - PerspectiveCameras, - ): - cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1 - cam_params["principal_point"] = torch.randn((batch_size, 2)) - elif cam_type == FishEyeCameras: - cam_params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1 - cam_params["principal_point"] = torch.randn((batch_size, 2)) - cam_params["radial_params"] = torch.randn((batch_size, 6)) - cam_params["tangential_params"] = torch.randn((batch_size, 2)) - cam_params["thin_prism_params"] = torch.randn((batch_size, 4)) - - else: - raise ValueError(str(cam_type)) - return cam_type(**cam_params) - - class TestCameraHelpers(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp()