From 6ae863f301c66b82c8caf18e12cbae17d2620415 Mon Sep 17 00:00:00 2001 From: Darijan Gudelj Date: Fri, 30 Sep 2022 04:03:01 -0700 Subject: [PATCH] Heterogeneous raysampling -> RayBundleHeterogeneous Summary: Added heterogeneous raysampling to pytorch3d raysampler, different cameras are sampled different number of times. It now returns RayBundle if heterogeneous raysampling is off and new RayBundleHeterogeneous (with added fields `camera_ids` and `camera_counts`). Heterogeneous raysampling is on if `n_rays_total` is not None. Reviewed By: bottler Differential Revision: D39542222 fbshipit-source-id: d3d88d822ec7696e856007c088dc36a1cfa8c625 --- pytorch3d/renderer/__init__.py | 1 + pytorch3d/renderer/implicit/__init__.py | 1 + pytorch3d/renderer/implicit/raysampling.py | 208 +++++++++++++++++++-- pytorch3d/renderer/implicit/renderer.py | 38 ++-- pytorch3d/renderer/implicit/utils.py | 63 +++++-- tests/test_raysampling.py | 62 +++++- 6 files changed, 325 insertions(+), 48 deletions(-) diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 07edc238..a667b012 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -31,6 +31,7 @@ from .implicit import ( EmissionAbsorptionRaymarcher, GridRaysampler, HarmonicEmbedding, + HeterogeneousRayBundle, ImplicitRenderer, MonteCarloRaysampler, MultinomialRaysampler, diff --git a/pytorch3d/renderer/implicit/__init__.py b/pytorch3d/renderer/implicit/__init__.py index bb617df8..39090112 100644 --- a/pytorch3d/renderer/implicit/__init__.py +++ b/pytorch3d/renderer/implicit/__init__.py @@ -15,6 +15,7 @@ from .raysampling import ( ) from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler from .utils import ( + HeterogeneousRayBundle, ray_bundle_to_ray_points, ray_bundle_variables_to_ray_points, RayBundle, diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index fb55ba7b..c53754e8 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -5,12 +5,13 @@ # LICENSE file in the root directory of this source tree. import warnings -from typing import Optional +from typing import Optional, Tuple, Union import torch from pytorch3d.common.compat import meshgrid_ij +from pytorch3d.ops import padded_to_packed from pytorch3d.renderer.cameras import CamerasBase -from pytorch3d.renderer.implicit.utils import RayBundle +from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle from torch.nn import functional as F @@ -73,6 +74,7 @@ class MultinomialRaysampler(torch.nn.Module): min_depth: float, max_depth: float, n_rays_per_image: Optional[int] = None, + n_rays_total: Optional[int] = None, unit_directions: bool = False, stratified_sampling: bool = False, ) -> None: @@ -88,6 +90,11 @@ class MultinomialRaysampler(torch.nn.Module): min_depth: The minimum depth of a ray-point. max_depth: The maximum depth of a ray-point. n_rays_per_image: If given, this amount of rays are sampled from the grid. + n_rays_total: How many rays in total to sample from the cameras provided. The result + is as if `n_rays_total` cameras were sampled with replacement from the + cameras provided and for every camera one ray was sampled. If set, this disables + `n_rays_per_image` and returns the HeterogeneousRayBundle with + batch_size=n_rays_total. unit_directions: whether to normalize direction vectors in ray bundle. stratified_sampling: if True, performs stratified random sampling along the ray; otherwise takes ray points at deterministic offsets. @@ -97,6 +104,7 @@ class MultinomialRaysampler(torch.nn.Module): self._min_depth = min_depth self._max_depth = max_depth self._n_rays_per_image = n_rays_per_image + self._n_rays_total = n_rays_total self._unit_directions = unit_directions self._stratified_sampling = stratified_sampling @@ -125,8 +133,9 @@ class MultinomialRaysampler(torch.nn.Module): n_rays_per_image: Optional[int] = None, n_pts_per_ray: Optional[int] = None, stratified_sampling: Optional[bool] = None, + n_rays_total: Optional[int] = None, **kwargs, - ) -> RayBundle: + ) -> Union[RayBundle, HeterogeneousRayBundle]: """ Args: cameras: A batch of `batch_size` cameras from which the rays are emitted. @@ -138,8 +147,15 @@ class MultinomialRaysampler(torch.nn.Module): n_pts_per_ray: The number of points sampled along each ray. stratified_sampling: if set, overrides stratified_sampling provided in __init__. + n_rays_total: How many rays in total to sample from the cameras provided. The result + is as if `n_rays_total_training` cameras were sampled with replacement from the + cameras provided and for every camera one ray was sampled. If set, this disables + `n_rays_per_image` and returns the HeterogeneousRayBundle with + batch_size=n_rays_total. Returns: - A named tuple RayBundle with the following fields: + A named tuple RayBundle or dataclass HeterogeneousRayBundle with the + following fields: + origins: A tensor of shape `(batch_size, s1, s2, 3)` denoting the locations of ray origins in the world coordinates. @@ -153,23 +169,56 @@ class MultinomialRaysampler(torch.nn.Module): `(batch_size, s1, s2, 2)` containing the 2D image coordinates of each ray or, if mask is given, `(batch_size, n, 1, 2)` - Here `s1, s2` refer to spatial dimensions. Unless the mask is - given, they equal `(image_height, image_width)`, otherwise `(n, 1)`, - where `n` is `n_rays_per_image` if provided, otherwise the minimum - cardinality of the mask in the batch. + Here `s1, s2` refer to spatial dimensions. + `(s1, s2)` refer to (highest priority first): + - `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total) + - `(n_rays_per_image, 1) if `n_rays_per_image` if provided, + - `(n, 1)` where n is the minimum cardinality of the mask + in the batch if `mask` is provided + - `(image_height, image_width)` if nothing from above is satisfied + + `HeterogeneousRayBundle` has additional members: + - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled + cameras. It represents unique ids of sampled cameras. + - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled + cameras. Represents how many times each camera from `camera_ids` was sampled + + `HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle` + is returned. """ + n_rays_total = n_rays_total or self._n_rays_total + n_rays_per_image = n_rays_per_image or self._n_rays_per_image + assert (n_rays_total is None) or ( + n_rays_per_image is None + ), "`n_rays_total` and `n_rays_per_image` cannot both be defined." + if n_rays_total: + ( + cameras, + mask, + camera_ids, # unique ids of sampled cameras + camera_counts, # number of times unique camera id was sampled + # `n_rays_per_image` is equal to the max number of times a simgle camera + # was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times + # and then discard the unneeded rays. + # pyre-ignore[9] + n_rays_per_image, + ) = _sample_cameras_and_masks(n_rays_total, cameras, mask) + else: + camera_ids = torch.range(0, len(cameras), dtype=torch.long) + batch_size = cameras.R.shape[0] device = cameras.device # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2) xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1) - num_rays = n_rays_per_image or self._n_rays_per_image - if mask is not None and num_rays is None: + if mask is not None and n_rays_per_image is None: # if num rays not given, sample according to the smallest mask - num_rays = num_rays or mask.sum(dim=(1, 2)).min().int().item() + n_rays_per_image = ( + n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item() + ) - if num_rays is not None: + if n_rays_per_image is not None: if mask is not None: assert mask.shape == xy_grid.shape[:3] weights = mask.reshape(batch_size, -1) @@ -181,7 +230,9 @@ class MultinomialRaysampler(torch.nn.Module): weights = xy_grid.new_ones(batch_size, width * height) # pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool, # float, int]`. - rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2) + rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand( + -1, -1, 2 + ) xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[ :, :, None @@ -198,7 +249,7 @@ class MultinomialRaysampler(torch.nn.Module): else self._stratified_sampling ) - return _xy_to_ray_bundle( + ray_bundle = _xy_to_ray_bundle( cameras, xy_grid, min_depth, @@ -208,6 +259,13 @@ class MultinomialRaysampler(torch.nn.Module): stratified_sampling, ) + return ( + # pyre-ignore[61] + _pack_ray_bundle(ray_bundle, camera_ids, camera_counts) + if n_rays_total + else ray_bundle + ) + class NDCMultinomialRaysampler(MultinomialRaysampler): """ @@ -231,6 +289,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler): min_depth: float, max_depth: float, n_rays_per_image: Optional[int] = None, + n_rays_total: Optional[int] = None, unit_directions: bool = False, stratified_sampling: bool = False, ) -> None: @@ -254,6 +313,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler): min_depth=min_depth, max_depth=max_depth, n_rays_per_image=n_rays_per_image, + n_rays_total=n_rays_total, unit_directions=unit_directions, stratified_sampling=stratified_sampling, ) @@ -281,6 +341,7 @@ class MonteCarloRaysampler(torch.nn.Module): min_depth: float, max_depth: float, *, + n_rays_total: Optional[int] = None, unit_directions: bool = False, stratified_sampling: bool = False, ) -> None: @@ -294,6 +355,11 @@ class MonteCarloRaysampler(torch.nn.Module): n_pts_per_ray: The number of points sampled along each ray. min_depth: The minimum depth of each ray-point. max_depth: The maximum depth of each ray-point. + n_rays_total: How many rays in total to sample from the cameras provided. The result + is as if `n_rays_total_training` cameras were sampled with replacement from the + cameras provided and for every camera one ray was sampled. If set, this disables + `n_rays_per_image` and returns the HeterogeneousRayBundleyBundle with + batch_size=n_rays_total. unit_directions: whether to normalize direction vectors in ray bundle. stratified_sampling: if True, performs stratified sampling in n_pts_per_ray bins for each ray; otherwise takes n_pts_per_ray deterministic points @@ -308,6 +374,7 @@ class MonteCarloRaysampler(torch.nn.Module): self._n_pts_per_ray = n_pts_per_ray self._min_depth = min_depth self._max_depth = max_depth + self._n_rays_total = n_rays_total self._unit_directions = unit_directions self._stratified_sampling = stratified_sampling @@ -317,15 +384,16 @@ class MonteCarloRaysampler(torch.nn.Module): *, stratified_sampling: Optional[bool] = None, **kwargs, - ) -> RayBundle: + ) -> Union[RayBundle, HeterogeneousRayBundle]: """ Args: cameras: A batch of `batch_size` cameras from which the rays are emitted. stratified_sampling: if set, overrides stratified_sampling provided in __init__. - Returns: - A named tuple RayBundle with the following fields: + A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the + following fields: + origins: A tensor of shape `(batch_size, n_rays_per_image, 3)` denoting the locations of ray origins in the world coordinates. @@ -338,7 +406,31 @@ class MonteCarloRaysampler(torch.nn.Module): xys: A tensor of shape `(batch_size, n_rays_per_image, 2)` containing the 2D image coordinates of each ray. + If `n_rays_total` is provided `batch_size=n_rays_total`and + `n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle` + is returned. + + `HeterogeneousRayBundle` has additional members: + - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled + cameras. It represents unique ids of sampled cameras. + - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled + cameras. Represents how many times each camera from `camera_ids` was sampled """ + assert (self._n_rays_total is None) or ( + self._n_rays_per_image is None + ), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined." + + if self._n_rays_total: + ( + cameras, + _, + camera_ids, + camera_counts, + n_rays_per_image, + ) = _sample_cameras_and_masks(self._n_rays_total, cameras, None) + else: + camera_ids = torch.range(0, len(cameras), dtype=torch.long) + n_rays_per_image = self._n_rays_per_image batch_size = cameras.R.shape[0] @@ -349,7 +441,7 @@ class MonteCarloRaysampler(torch.nn.Module): rays_xy = torch.cat( [ torch.rand( - size=(batch_size, self._n_rays_per_image, 1), + size=(batch_size, n_rays_per_image, 1), dtype=torch.float32, device=device, ) @@ -369,7 +461,7 @@ class MonteCarloRaysampler(torch.nn.Module): else self._stratified_sampling ) - return _xy_to_ray_bundle( + ray_bundle = _xy_to_ray_bundle( cameras, rays_xy, self._min_depth, @@ -379,6 +471,13 @@ class MonteCarloRaysampler(torch.nn.Module): stratified_sampling, ) + return ( + # pyre-ignore[61] + _pack_ray_bundle(ray_bundle, camera_ids, camera_counts) + if self._n_rays_total + else ray_bundle + ) + # Settings for backwards compatibility def GridRaysampler( @@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor: # Samples in those intervals. jiggled = lower + (upper - lower) * torch.rand_like(lower) return jiggled + + +def _sample_cameras_and_masks( + n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None +) -> Tuple[ + CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor +]: + """ + Samples n_rays_total cameras and masks and returns them in a form + (camera_idx, count), where count represents number of times the same camera + has been sampled. + + Args: + n_samples: how many camera and mask pairs to sample + cameras: A batch of `batch_size` cameras from which the rays are emitted. + mask: Optional. Should be of size (batch_size, image_height, image_width). + Returns: + tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids, + number_of_times_each_sampled_camera_has_been_sampled, + max_number_of_times_camera_has_been_sampled, + ) + """ + sampled_ids = torch.randint( + 0, + len(cameras), + size=(n_samples,), + dtype=torch.long, + ) + unique_ids, counts = torch.unique(sampled_ids, return_counts=True) + return ( + cameras[unique_ids], + mask[unique_ids] if mask is not None else None, + unique_ids, + counts, + torch.max(counts), + ) + + +def _pack_ray_bundle( + ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor +) -> HeterogeneousRayBundle: + """ + Pack the raybundle from [n_cameras, max(rays_per_camera), ...] to + [total_num_rays, 1, ...] + + Args: + ray_bundle: A ray_bundle to pack + camera_ids: Unique ids of cameras that were sampled + camera_counts: how many of which camera to pack, each count coresponds to + one 'row' of the ray_bundle and says how many rays wll be taken + from it and packed. + Returns: + HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1 + """ + camera_counts = camera_counts.to(ray_bundle.origins.device) + cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long) + first_idxs = torch.cat( + (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1]) + ) + num_inputs = int(camera_counts.sum()) + + return HeterogeneousRayBundle( + origins=padded_to_packed(ray_bundle.origins, first_idxs, num_inputs)[:, None], + directions=padded_to_packed(ray_bundle.directions, first_idxs, num_inputs)[ + :, None + ], + lengths=padded_to_packed(ray_bundle.lengths, first_idxs, num_inputs)[:, None], + xys=padded_to_packed(ray_bundle.xys, first_idxs, num_inputs)[:, None], + camera_ids=camera_ids, + camera_counts=camera_counts, + ) diff --git a/pytorch3d/renderer/implicit/renderer.py b/pytorch3d/renderer/implicit/renderer.py index f476fb3b..c2be5adc 100644 --- a/pytorch3d/renderer/implicit/renderer.py +++ b/pytorch3d/renderer/implicit/renderer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Tuple +from typing import Callable, Tuple, Union import torch @@ -12,7 +12,7 @@ from ...ops.utils import eyes from ...structures import Volumes from ...transforms import Transform3d from ..cameras import CamerasBase -from .raysampling import RayBundle +from .raysampling import HeterogeneousRayBundle, RayBundle from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points @@ -44,12 +44,14 @@ class ImplicitRenderer(torch.nn.Module): A standard `volumetric_function` has the following signature: ``` def volumetric_function( - ray_bundle: RayBundle, + ray_bundle: Union[RayBundle, HeterogeneousRayBundle], **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor] ``` With the following arguments: - `ray_bundle`: A RayBundle object containing the following variables: + `ray_bundle`: A RayBundle or HeterogeneousRayBundle object + containing the following variables: + `origins`: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the rendering rays. `directions`: A tensor of shape `(minibatch, ..., 3)` @@ -80,7 +82,7 @@ class ImplicitRenderer(torch.nn.Module): RGB sphere with a unit diameter is defined as follows: ``` def volumetric_function( - ray_bundle: RayBundle, + ray_bundle: Union[RayBundle, HeterogeneousRayBundle], **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -109,7 +111,8 @@ class ImplicitRenderer(torch.nn.Module): """ Args: raysampler: A `Callable` that takes as input scene cameras - (an instance of `CamerasBase`) and returns a `RayBundle` that + (an instance of `CamerasBase`) and returns a + RayBundle or HeterogeneousRayBundle, that describes the rays emitted from the cameras. raymarcher: A `Callable` that receives the response of the `volumetric_function` (an input to `self.forward`) evaluated @@ -128,7 +131,7 @@ class ImplicitRenderer(torch.nn.Module): def forward( self, cameras: CamerasBase, volumetric_function: Callable, **kwargs - ) -> Tuple[torch.Tensor, RayBundle]: + ) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]: """ Render a batch of images using a volumetric function represented as a callable (e.g. a Pytorch module). @@ -145,15 +148,15 @@ class ImplicitRenderer(torch.nn.Module): Returns: images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)` containing the result of the rendering. - ray_bundle: A `RayBundle` containing the parametrizations of the - sampled rendering rays. + ray_bundle: A `Union[RayBundle, HeterogeneousRayBundle]` containing + the parametrizations of the sampled rendering rays. """ if not callable(volumetric_function): raise ValueError('"volumetric_function" has to be a "Callable" object.') - # first call the ray sampler that returns the RayBundle parametrizing - # the rendering rays. + # first call the ray sampler that returns the RayBundle or HeterogeneousRayBundle + # parametrizing the rendering rays. ray_bundle = self.raysampler( cameras=cameras, volumetric_function=volumetric_function, **kwargs ) @@ -211,7 +214,8 @@ class VolumeRenderer(torch.nn.Module): """ Args: raysampler: A `Callable` that takes as input scene cameras - (an instance of `CamerasBase`) and returns a `RayBundle` that + (an instance of `CamerasBase`) and returns a + `Union[RayBundle, HeterogeneousRayBundle],` that describes the rays emitted from the cameras. raymarcher: A `Callable` that receives the `volumes` (an instance of `Volumes` input to `self.forward`) @@ -227,7 +231,7 @@ class VolumeRenderer(torch.nn.Module): def forward( self, cameras: CamerasBase, volumes: Volumes, **kwargs - ) -> Tuple[torch.Tensor, RayBundle]: + ) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]: """ Render a batch of images using raymarching over rays cast through input `Volumes`. @@ -242,8 +246,8 @@ class VolumeRenderer(torch.nn.Module): Returns: images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)` containing the result of the rendering. - ray_bundle: A `RayBundle` containing the parametrizations of the - sampled rendering rays. + ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` containing the + parametrizations of the sampled rendering rays. """ volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode) return self.renderer( @@ -288,14 +292,14 @@ class VolumeSampler(torch.nn.Module): return directions_transform def forward( - self, ray_bundle: RayBundle, **kwargs + self, ray_bundle: Union[RayBundle, HeterogeneousRayBundle], **kwargs ) -> Tuple[torch.Tensor, torch.Tensor]: """ Given an input ray parametrization, the forward function samples `self._volumes` at the respective 3D ray-points. Args: - ray_bundle: A RayBundle object with the following fields: + ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields: rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the sampling rays in world coords. rays_directions_world: A tensor of shape `(minibatch, ..., 3)` diff --git a/pytorch3d/renderer/implicit/utils.py b/pytorch3d/renderer/implicit/utils.py index 4d26391b..6ccae29b 100644 --- a/pytorch3d/renderer/implicit/utils.py +++ b/pytorch3d/renderer/implicit/utils.py @@ -4,19 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import NamedTuple +import dataclasses +from typing import NamedTuple, Optional, Union import torch class RayBundle(NamedTuple): """ - RayBundle parametrizes points along projection rays by storing ray `origins`, - `directions` vectors and `lengths` at which the ray-points are sampled. - Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well. - Note that `directions` don't have to be normalized; they define unit vectors - in the respective 1D coordinate systems; see documentation for - :func:`ray_bundle_to_ray_points` for the conversion formula. + Parametrizes points along projection rays by storing: + + origins: A tensor of shape `(..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(..., 3)` containing the direction + vectors of sampling rays in world coords. They don't have to be normalized; + they define unit vectors in the respective 1D coordinate systems; see + documentation for :func:`ray_bundle_to_ray_points` for the conversion formula. + lengths: A tensor of shape `(..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels """ origins: torch.Tensor @@ -25,11 +31,46 @@ class RayBundle(NamedTuple): xys: torch.Tensor -def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor: +@dataclasses.dataclass +class HeterogeneousRayBundle: + """ + Members: + origins: A tensor of shape `(..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(..., 3)` containing the direction + vectors of sampling rays in world coords. They don't have to be normalized; + they define unit vectors in the respective 1D coordinate systems; see + documentation for :func:`ray_bundle_to_ray_points` for the conversion formula. + lengths: A tensor of shape `(..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels + camera_ids: A tensor of shape (N, ) which indicates which camera + was used to sample the rays. `N` is the number of unique sampled cameras. + camera_counts: A tensor of shape (N, ) which how many times the + coresponding camera in `camera_ids` was sampled. + `sum(camera_counts)==total_number_of_rays` + + If we sample cameras of ids [0, 3, 5, 3, 1, 0, 0] that would be + stored as camera_ids=[1, 3, 5, 0] and camera_counts=[1, 2, 1, 3]. `camera_ids` is a + set like object with no particular ordering of elements. ith element of + `camera_ids` coresponds to the ith element of `camera_counts`. + """ + + origins: torch.Tensor + directions: torch.Tensor + lengths: torch.Tensor + xys: torch.Tensor + camera_ids: Optional[torch.Tensor] = None + camera_counts: Optional[torch.Tensor] = None + + +def ray_bundle_to_ray_points( + ray_bundle: Union[RayBundle, HeterogeneousRayBundle] +) -> torch.Tensor: """ Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle` - named tuple) to 3D points by extending each ray according to the corresponding - length. + named tuple or HeterogeneousRayBundle dataclass) to 3D points by + extending each ray according to the corresponding length. E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions` and `ray_bundle.lengths`, the ray point at position `[i, j]` is: @@ -43,7 +84,7 @@ def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor: `ray_bundle.directions` matter. Args: - ray_bundle: A `RayBundle` object with fields: + ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields: origins: A tensor of shape `(..., 3)` directions: A tensor of shape `(..., 3)` lengths: A tensor of shape `(..., num_points_per_ray)` diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index a731aa8d..d05041c0 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -152,6 +152,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): min_depth=1.0, max_depth=10.0, n_pts_per_ray=10, + n_rays_total=None, + n_rays_per_image=None, ): raysampler_params = { "min_x": min_x, @@ -161,6 +163,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): "n_pts_per_ray": n_pts_per_ray, "min_depth": min_depth, "max_depth": max_depth, + "n_rays_total": n_rays_total, + "n_rays_per_image": n_rays_per_image, } if issubclass(raysampler_type, MultinomialRaysampler): @@ -168,7 +172,11 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): {"image_width": image_width, "image_height": image_height} ) elif issubclass(raysampler_type, MonteCarloRaysampler): - raysampler_params["n_rays_per_image"] = image_width * image_height + raysampler_params["n_rays_per_image"] = ( + image_width * image_height + if (n_rays_total is None) and (n_rays_per_image is None) + else n_rays_per_image + ) else: raise ValueError(str(raysampler_type)) @@ -580,3 +588,55 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): # samples[3] has enough sources, so must contain 3 distinct values. self.assertLessEqual(samples[3].max(), 3) self.assertEqual(len(set(samples[3].tolist())), 3) + + def test_heterogeneous_sampling(self, batch_size=8): + """ + Test that the output of heterogeneous sampling has the first dimension equal + to n_rays_total and second to 1 and that ray_bundle elements from different + sampled cameras are different and equal for same sampled cameras. + """ + cameras = init_random_cameras(PerspectiveCameras, batch_size, random_z=True) + for n_rays_total in [2, 3, 17, 21, 32]: + for cls in (MultinomialRaysampler, MonteCarloRaysampler): + with self.subTest(cls.__name__ + ", n_rays_total=" + str(n_rays_total)): + raysampler = self.init_raysampler( + cls, n_rays_total=n_rays_total, n_rays_per_image=None + ) + ray_bundle = raysampler(cameras) + + # test weather they are of the correct shape + for attr in ("origins", "directions", "lengths", "xys"): + tensor = getattr(ray_bundle, attr) + assert tensor.shape[:2] == torch.Size( + (n_rays_total, 1) + ), tensor.shape + + # if two camera ids are same than origins should also be the same + # directions and xys are always different and lengths equal + for i1, (origin1, dir1, len1, id1) in enumerate( + zip( + ray_bundle.origins, + ray_bundle.directions, + ray_bundle.lengths, + torch.repeat_interleave( + ray_bundle.camera_ids, ray_bundle.camera_counts + ), + ) + ): + for i2, (origin2, dir2, len2, id2) in enumerate( + zip( + ray_bundle.origins, + ray_bundle.directions, + ray_bundle.lengths, + torch.repeat_interleave( + ray_bundle.camera_ids, ray_bundle.camera_counts + ), + ) + ): + if i1 == i2: + continue + assert torch.allclose( + origin1, origin2, rtol=1e-4, atol=1e-4 + ) == (id1 == id2), (origin1, origin2, id1, id2) + assert not torch.allclose(dir1, dir2), (dir1, dir2) + self.assertClose(len1, len2), (len1, len2)