Heterogeneous raysampling -> RayBundleHeterogeneous

Summary:
Added heterogeneous raysampling to the pytorch3d raysampler: different cameras can be sampled a different number of times.

The raysampler now returns `RayBundle` if heterogeneous raysampling is off, and the new `HeterogeneousRayBundle` (with added fields `camera_ids` and `camera_counts`) if it is on. Heterogeneous raysampling is on if `n_rays_total` is not None.

Reviewed By: bottler

Differential Revision: D39542222

fbshipit-source-id: d3d88d822ec7696e856007c088dc36a1cfa8c625
This commit is contained in:
Darijan Gudelj 2022-09-30 04:03:01 -07:00 committed by Facebook GitHub Bot
parent 9a0f9ae572
commit 6ae863f301
6 changed files with 325 additions and 48 deletions

View File

@ -31,6 +31,7 @@ from .implicit import (
EmissionAbsorptionRaymarcher, EmissionAbsorptionRaymarcher,
GridRaysampler, GridRaysampler,
HarmonicEmbedding, HarmonicEmbedding,
HeterogeneousRayBundle,
ImplicitRenderer, ImplicitRenderer,
MonteCarloRaysampler, MonteCarloRaysampler,
MultinomialRaysampler, MultinomialRaysampler,

View File

@ -15,6 +15,7 @@ from .raysampling import (
) )
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
from .utils import ( from .utils import (
HeterogeneousRayBundle,
ray_bundle_to_ray_points, ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points, ray_bundle_variables_to_ray_points,
RayBundle, RayBundle,

View File

@ -5,12 +5,13 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import warnings import warnings
from typing import Optional from typing import Optional, Tuple, Union
import torch import torch
from pytorch3d.common.compat import meshgrid_ij from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.ops import padded_to_packed
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit.utils import RayBundle from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle
from torch.nn import functional as F from torch.nn import functional as F
@ -73,6 +74,7 @@ class MultinomialRaysampler(torch.nn.Module):
min_depth: float, min_depth: float,
max_depth: float, max_depth: float,
n_rays_per_image: Optional[int] = None, n_rays_per_image: Optional[int] = None,
n_rays_total: Optional[int] = None,
unit_directions: bool = False, unit_directions: bool = False,
stratified_sampling: bool = False, stratified_sampling: bool = False,
) -> None: ) -> None:
@ -88,6 +90,11 @@ class MultinomialRaysampler(torch.nn.Module):
min_depth: The minimum depth of a ray-point. min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point. max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this amount of rays are sampled from the grid. n_rays_per_image: If given, this amount of rays are sampled from the grid.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
unit_directions: whether to normalize direction vectors in ray bundle. unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified random sampling stratified_sampling: if True, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets. along the ray; otherwise takes ray points at deterministic offsets.
@ -97,6 +104,7 @@ class MultinomialRaysampler(torch.nn.Module):
self._min_depth = min_depth self._min_depth = min_depth
self._max_depth = max_depth self._max_depth = max_depth
self._n_rays_per_image = n_rays_per_image self._n_rays_per_image = n_rays_per_image
self._n_rays_total = n_rays_total
self._unit_directions = unit_directions self._unit_directions = unit_directions
self._stratified_sampling = stratified_sampling self._stratified_sampling = stratified_sampling
@ -125,8 +133,9 @@ class MultinomialRaysampler(torch.nn.Module):
n_rays_per_image: Optional[int] = None, n_rays_per_image: Optional[int] = None,
n_pts_per_ray: Optional[int] = None, n_pts_per_ray: Optional[int] = None,
stratified_sampling: Optional[bool] = None, stratified_sampling: Optional[bool] = None,
n_rays_total: Optional[int] = None,
**kwargs, **kwargs,
) -> RayBundle: ) -> Union[RayBundle, HeterogeneousRayBundle]:
""" """
Args: Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted. cameras: A batch of `batch_size` cameras from which the rays are emitted.
@ -138,8 +147,15 @@ class MultinomialRaysampler(torch.nn.Module):
n_pts_per_ray: The number of points sampled along each ray. n_pts_per_ray: The number of points sampled along each ray.
stratified_sampling: if set, overrides stratified_sampling provided stratified_sampling: if set, overrides stratified_sampling provided
in __init__. in __init__.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
Returns: Returns:
A named tuple RayBundle with the following fields: A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
following fields:
origins: A tensor of shape origins: A tensor of shape
`(batch_size, s1, s2, 3)` `(batch_size, s1, s2, 3)`
denoting the locations of ray origins in the world coordinates. denoting the locations of ray origins in the world coordinates.
@ -153,23 +169,56 @@ class MultinomialRaysampler(torch.nn.Module):
`(batch_size, s1, s2, 2)` `(batch_size, s1, s2, 2)`
containing the 2D image coordinates of each ray or, containing the 2D image coordinates of each ray or,
if mask is given, `(batch_size, n, 1, 2)` if mask is given, `(batch_size, n, 1, 2)`
Here `s1, s2` refer to spatial dimensions. Unless the mask is Here `s1, s2` refer to spatial dimensions.
given, they equal `(image_height, image_width)`, otherwise `(n, 1)`, `(s1, s2)` refer to (highest priority first):
where `n` is `n_rays_per_image` if provided, otherwise the minimum - `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
cardinality of the mask in the batch. - `(n_rays_per_image, 1) if `n_rays_per_image` if provided,
- `(n, 1)` where n is the minimum cardinality of the mask
in the batch if `mask` is provided
- `(image_height, image_width)` if nothing from above is satisfied
`HeterogeneousRayBundle` has additional members:
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
cameras. It represents unique ids of sampled cameras.
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
cameras. Represents how many times each camera from `camera_ids` was sampled
`HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
is returned.
""" """
n_rays_total = n_rays_total or self._n_rays_total
n_rays_per_image = n_rays_per_image or self._n_rays_per_image
assert (n_rays_total is None) or (
n_rays_per_image is None
), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
if n_rays_total:
(
cameras,
mask,
camera_ids, # unique ids of sampled cameras
camera_counts, # number of times unique camera id was sampled
# `n_rays_per_image` is equal to the max number of times a single camera
# was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
# and then discard the unneeded rays.
# pyre-ignore[9]
n_rays_per_image,
) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
else:
camera_ids = torch.range(0, len(cameras), dtype=torch.long)
batch_size = cameras.R.shape[0] batch_size = cameras.R.shape[0]
device = cameras.device device = cameras.device
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2) # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1) xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)
num_rays = n_rays_per_image or self._n_rays_per_image if mask is not None and n_rays_per_image is None:
if mask is not None and num_rays is None:
# if num rays not given, sample according to the smallest mask # if num rays not given, sample according to the smallest mask
num_rays = num_rays or mask.sum(dim=(1, 2)).min().int().item() n_rays_per_image = (
n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
)
if num_rays is not None: if n_rays_per_image is not None:
if mask is not None: if mask is not None:
assert mask.shape == xy_grid.shape[:3] assert mask.shape == xy_grid.shape[:3]
weights = mask.reshape(batch_size, -1) weights = mask.reshape(batch_size, -1)
@ -181,7 +230,9 @@ class MultinomialRaysampler(torch.nn.Module):
weights = xy_grid.new_ones(batch_size, width * height) weights = xy_grid.new_ones(batch_size, width * height)
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool, # pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
# float, int]`. # float, int]`.
rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2) rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand(
-1, -1, 2
)
xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[ xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
:, :, None :, :, None
@ -198,7 +249,7 @@ class MultinomialRaysampler(torch.nn.Module):
else self._stratified_sampling else self._stratified_sampling
) )
return _xy_to_ray_bundle( ray_bundle = _xy_to_ray_bundle(
cameras, cameras,
xy_grid, xy_grid,
min_depth, min_depth,
@ -208,6 +259,13 @@ class MultinomialRaysampler(torch.nn.Module):
stratified_sampling, stratified_sampling,
) )
return (
# pyre-ignore[61]
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
if n_rays_total
else ray_bundle
)
class NDCMultinomialRaysampler(MultinomialRaysampler): class NDCMultinomialRaysampler(MultinomialRaysampler):
""" """
@ -231,6 +289,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler):
min_depth: float, min_depth: float,
max_depth: float, max_depth: float,
n_rays_per_image: Optional[int] = None, n_rays_per_image: Optional[int] = None,
n_rays_total: Optional[int] = None,
unit_directions: bool = False, unit_directions: bool = False,
stratified_sampling: bool = False, stratified_sampling: bool = False,
) -> None: ) -> None:
@ -254,6 +313,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler):
min_depth=min_depth, min_depth=min_depth,
max_depth=max_depth, max_depth=max_depth,
n_rays_per_image=n_rays_per_image, n_rays_per_image=n_rays_per_image,
n_rays_total=n_rays_total,
unit_directions=unit_directions, unit_directions=unit_directions,
stratified_sampling=stratified_sampling, stratified_sampling=stratified_sampling,
) )
@ -281,6 +341,7 @@ class MonteCarloRaysampler(torch.nn.Module):
min_depth: float, min_depth: float,
max_depth: float, max_depth: float,
*, *,
n_rays_total: Optional[int] = None,
unit_directions: bool = False, unit_directions: bool = False,
stratified_sampling: bool = False, stratified_sampling: bool = False,
) -> None: ) -> None:
@ -294,6 +355,11 @@ class MonteCarloRaysampler(torch.nn.Module):
n_pts_per_ray: The number of points sampled along each ray. n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of each ray-point. min_depth: The minimum depth of each ray-point.
max_depth: The maximum depth of each ray-point. max_depth: The maximum depth of each ray-point.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
unit_directions: whether to normalize direction vectors in ray bundle. unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
bins for each ray; otherwise takes n_pts_per_ray deterministic points bins for each ray; otherwise takes n_pts_per_ray deterministic points
@ -308,6 +374,7 @@ class MonteCarloRaysampler(torch.nn.Module):
self._n_pts_per_ray = n_pts_per_ray self._n_pts_per_ray = n_pts_per_ray
self._min_depth = min_depth self._min_depth = min_depth
self._max_depth = max_depth self._max_depth = max_depth
self._n_rays_total = n_rays_total
self._unit_directions = unit_directions self._unit_directions = unit_directions
self._stratified_sampling = stratified_sampling self._stratified_sampling = stratified_sampling
@ -317,15 +384,16 @@ class MonteCarloRaysampler(torch.nn.Module):
*, *,
stratified_sampling: Optional[bool] = None, stratified_sampling: Optional[bool] = None,
**kwargs, **kwargs,
) -> RayBundle: ) -> Union[RayBundle, HeterogeneousRayBundle]:
""" """
Args: Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted. cameras: A batch of `batch_size` cameras from which the rays are emitted.
stratified_sampling: if set, overrides stratified_sampling provided stratified_sampling: if set, overrides stratified_sampling provided
in __init__. in __init__.
Returns: Returns:
A named tuple RayBundle with the following fields: A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
following fields:
origins: A tensor of shape origins: A tensor of shape
`(batch_size, n_rays_per_image, 3)` `(batch_size, n_rays_per_image, 3)`
denoting the locations of ray origins in the world coordinates. denoting the locations of ray origins in the world coordinates.
@ -338,7 +406,31 @@ class MonteCarloRaysampler(torch.nn.Module):
xys: A tensor of shape xys: A tensor of shape
`(batch_size, n_rays_per_image, 2)` `(batch_size, n_rays_per_image, 2)`
containing the 2D image coordinates of each ray. containing the 2D image coordinates of each ray.
If `n_rays_total` is provided `batch_size=n_rays_total` and
`n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
is returned.
`HeterogeneousRayBundle` has additional members:
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
cameras. It represents unique ids of sampled cameras.
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
cameras. Represents how many times each camera from `camera_ids` was sampled
""" """
assert (self._n_rays_total is None) or (
self._n_rays_per_image is None
), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."
if self._n_rays_total:
(
cameras,
_,
camera_ids,
camera_counts,
n_rays_per_image,
) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
else:
camera_ids = torch.range(0, len(cameras), dtype=torch.long)
n_rays_per_image = self._n_rays_per_image
batch_size = cameras.R.shape[0] batch_size = cameras.R.shape[0]
@ -349,7 +441,7 @@ class MonteCarloRaysampler(torch.nn.Module):
rays_xy = torch.cat( rays_xy = torch.cat(
[ [
torch.rand( torch.rand(
size=(batch_size, self._n_rays_per_image, 1), size=(batch_size, n_rays_per_image, 1),
dtype=torch.float32, dtype=torch.float32,
device=device, device=device,
) )
@ -369,7 +461,7 @@ class MonteCarloRaysampler(torch.nn.Module):
else self._stratified_sampling else self._stratified_sampling
) )
return _xy_to_ray_bundle( ray_bundle = _xy_to_ray_bundle(
cameras, cameras,
rays_xy, rays_xy,
self._min_depth, self._min_depth,
@ -379,6 +471,13 @@ class MonteCarloRaysampler(torch.nn.Module):
stratified_sampling, stratified_sampling,
) )
return (
# pyre-ignore[61]
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
if self._n_rays_total
else ray_bundle
)
# Settings for backwards compatibility # Settings for backwards compatibility
def GridRaysampler( def GridRaysampler(
@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
# Samples in those intervals. # Samples in those intervals.
jiggled = lower + (upper - lower) * torch.rand_like(lower) jiggled = lower + (upper - lower) * torch.rand_like(lower)
return jiggled return jiggled
def _sample_cameras_and_masks(
n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
) -> Tuple[
CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
]:
"""
Samples n_rays_total cameras and masks and returns them in a form
(camera_idx, count), where count represents number of times the same camera
has been sampled.
Args:
n_samples: how many camera and mask pairs to sample
cameras: A batch of `batch_size` cameras from which the rays are emitted.
mask: Optional. Should be of size (batch_size, image_height, image_width).
Returns:
tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids,
number_of_times_each_sampled_camera_has_been_sampled,
max_number_of_times_camera_has_been_sampled,
)
"""
sampled_ids = torch.randint(
0,
len(cameras),
size=(n_samples,),
dtype=torch.long,
)
unique_ids, counts = torch.unique(sampled_ids, return_counts=True)
return (
cameras[unique_ids],
mask[unique_ids] if mask is not None else None,
unique_ids,
counts,
torch.max(counts),
)
def _pack_ray_bundle(
    ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
) -> HeterogeneousRayBundle:
    """
    Repack a padded raybundle of shape [n_cameras, max(rays_per_camera), ...]
    into a packed one of shape [total_num_rays, 1, ...].

    Args:
        ray_bundle: The ray_bundle to pack.
        camera_ids: Unique ids of the cameras that were sampled. Stored in the
            result as given.
        camera_counts: One count per 'row' of `ray_bundle`, saying how many
            rays are taken from that row and packed.

    Returns:
        HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1
    """
    # The counts have to live on the same device as the ray tensors.
    counts_on_device = camera_counts.to(ray_bundle.origins.device)
    total_rays = int(counts_on_device.sum())

    # Exclusive prefix sum of the counts: a leading zero followed by the
    # running totals without their last entry.
    running_total = torch.cumsum(counts_on_device, dim=0, dtype=torch.long)
    leading_zero = counts_on_device.new_zeros((1,), dtype=torch.long)
    row_starts = torch.cat((leading_zero, running_total[:-1]))

    def _pack(padded: torch.Tensor) -> torch.Tensor:
        # Pack one field and re-insert a singleton rays-per-image dimension.
        return padded_to_packed(padded, row_starts, total_rays)[:, None]

    return HeterogeneousRayBundle(
        origins=_pack(ray_bundle.origins),
        directions=_pack(ray_bundle.directions),
        lengths=_pack(ray_bundle.lengths),
        xys=_pack(ray_bundle.xys),
        camera_ids=camera_ids,
        camera_counts=counts_on_device,
    )

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Callable, Tuple from typing import Callable, Tuple, Union
import torch import torch
@ -12,7 +12,7 @@ from ...ops.utils import eyes
from ...structures import Volumes from ...structures import Volumes
from ...transforms import Transform3d from ...transforms import Transform3d
from ..cameras import CamerasBase from ..cameras import CamerasBase
from .raysampling import RayBundle from .raysampling import HeterogeneousRayBundle, RayBundle
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
@ -44,12 +44,14 @@ class ImplicitRenderer(torch.nn.Module):
A standard `volumetric_function` has the following signature: A standard `volumetric_function` has the following signature:
``` ```
def volumetric_function( def volumetric_function(
ray_bundle: RayBundle, ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]
``` ```
With the following arguments: With the following arguments:
`ray_bundle`: A RayBundle object containing the following variables: `ray_bundle`: A RayBundle or HeterogeneousRayBundle object
containing the following variables:
`origins`: A tensor of shape `(minibatch, ..., 3)` denoting `origins`: A tensor of shape `(minibatch, ..., 3)` denoting
the origins of the rendering rays. the origins of the rendering rays.
`directions`: A tensor of shape `(minibatch, ..., 3)` `directions`: A tensor of shape `(minibatch, ..., 3)`
@ -80,7 +82,7 @@ class ImplicitRenderer(torch.nn.Module):
RGB sphere with a unit diameter is defined as follows: RGB sphere with a unit diameter is defined as follows:
``` ```
def volumetric_function( def volumetric_function(
ray_bundle: RayBundle, ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -109,7 +111,8 @@ class ImplicitRenderer(torch.nn.Module):
""" """
Args: Args:
raysampler: A `Callable` that takes as input scene cameras raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that (an instance of `CamerasBase`) and returns a
RayBundle or HeterogeneousRayBundle, that
describes the rays emitted from the cameras. describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the response of the raymarcher: A `Callable` that receives the response of the
`volumetric_function` (an input to `self.forward`) evaluated `volumetric_function` (an input to `self.forward`) evaluated
@ -128,7 +131,7 @@ class ImplicitRenderer(torch.nn.Module):
def forward( def forward(
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
) -> Tuple[torch.Tensor, RayBundle]: ) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]:
""" """
Render a batch of images using a volumetric function Render a batch of images using a volumetric function
represented as a callable (e.g. a Pytorch module). represented as a callable (e.g. a Pytorch module).
@ -145,15 +148,15 @@ class ImplicitRenderer(torch.nn.Module):
Returns: Returns:
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)` images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
containing the result of the rendering. containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the ray_bundle: A `Union[RayBundle, HeterogeneousRayBundle]` containing
sampled rendering rays. the parametrizations of the sampled rendering rays.
""" """
if not callable(volumetric_function): if not callable(volumetric_function):
raise ValueError('"volumetric_function" has to be a "Callable" object.') raise ValueError('"volumetric_function" has to be a "Callable" object.')
# first call the ray sampler that returns the RayBundle parametrizing # first call the ray sampler that returns the RayBundle or HeterogeneousRayBundle
# the rendering rays. # parametrizing the rendering rays.
ray_bundle = self.raysampler( ray_bundle = self.raysampler(
cameras=cameras, volumetric_function=volumetric_function, **kwargs cameras=cameras, volumetric_function=volumetric_function, **kwargs
) )
@ -211,7 +214,8 @@ class VolumeRenderer(torch.nn.Module):
""" """
Args: Args:
raysampler: A `Callable` that takes as input scene cameras raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that (an instance of `CamerasBase`) and returns a
`Union[RayBundle, HeterogeneousRayBundle],` that
describes the rays emitted from the cameras. describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the `volumes` raymarcher: A `Callable` that receives the `volumes`
(an instance of `Volumes` input to `self.forward`) (an instance of `Volumes` input to `self.forward`)
@ -227,7 +231,7 @@ class VolumeRenderer(torch.nn.Module):
def forward( def forward(
self, cameras: CamerasBase, volumes: Volumes, **kwargs self, cameras: CamerasBase, volumes: Volumes, **kwargs
) -> Tuple[torch.Tensor, RayBundle]: ) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]:
""" """
Render a batch of images using raymarching over rays cast through Render a batch of images using raymarching over rays cast through
input `Volumes`. input `Volumes`.
@ -242,8 +246,8 @@ class VolumeRenderer(torch.nn.Module):
Returns: Returns:
images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)` images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)`
containing the result of the rendering. containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` containing the
sampled rendering rays. parametrizations of the sampled rendering rays.
""" """
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode) volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
return self.renderer( return self.renderer(
@ -288,14 +292,14 @@ class VolumeSampler(torch.nn.Module):
return directions_transform return directions_transform
def forward( def forward(
self, ray_bundle: RayBundle, **kwargs self, ray_bundle: Union[RayBundle, HeterogeneousRayBundle], **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Given an input ray parametrization, the forward function samples Given an input ray parametrization, the forward function samples
`self._volumes` at the respective 3D ray-points. `self._volumes` at the respective 3D ray-points.
Args: Args:
ray_bundle: A RayBundle object with the following fields: ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields:
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords. origins of the sampling rays in world coords.
rays_directions_world: A tensor of shape `(minibatch, ..., 3)` rays_directions_world: A tensor of shape `(minibatch, ..., 3)`

View File

@ -4,19 +4,25 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import NamedTuple import dataclasses
from typing import NamedTuple, Optional, Union
import torch import torch
class RayBundle(NamedTuple): class RayBundle(NamedTuple):
""" """
RayBundle parametrizes points along projection rays by storing ray `origins`, Parametrizes points along projection rays by storing:
`directions` vectors and `lengths` at which the ray-points are sampled.
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well. origins: A tensor of shape `(..., 3)` denoting the
Note that `directions` don't have to be normalized; they define unit vectors origins of the sampling rays in world coords.
in the respective 1D coordinate systems; see documentation for directions: A tensor of shape `(..., 3)` containing the direction
:func:`ray_bundle_to_ray_points` for the conversion formula. vectors of sampling rays in world coords. They don't have to be normalized;
they define unit vectors in the respective 1D coordinate systems; see
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
lengths: A tensor of shape `(..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
""" """
origins: torch.Tensor origins: torch.Tensor
@ -25,11 +31,46 @@ class RayBundle(NamedTuple):
xys: torch.Tensor xys: torch.Tensor
def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor: @dataclasses.dataclass
class HeterogeneousRayBundle:
    """
    Parametrizes points along projection rays in the same way as `RayBundle`
    (`origins`, `directions`, `lengths`, `xys`) and additionally records which
    cameras the rays were sampled from and how many times each was sampled.

    Members:
        origins: A tensor of shape `(..., 3)` denoting the
            origins of the sampling rays in world coords.
        directions: A tensor of shape `(..., 3)` containing the direction
            vectors of sampling rays in world coords. They don't have to be normalized;
            they define unit vectors in the respective 1D coordinate systems; see
            documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
        lengths: A tensor of shape `(..., num_points_per_ray)`
            containing the lengths at which the rays are sampled.
        xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
        camera_ids: A tensor of shape (N, ) which indicates which cameras
            were used to sample the rays. `N` is the number of unique sampled cameras.
            Defaults to None.
        camera_counts: A tensor of shape (N, ) which indicates how many times the
            corresponding camera in `camera_ids` was sampled.
            `sum(camera_counts)==total_number_of_rays`. Defaults to None.

    If we sample cameras of ids [0, 3, 5, 3, 1, 0, 0] that would be
    stored as camera_ids=[1, 3, 5, 0] and camera_counts=[1, 2, 1, 3]. `camera_ids` is a
    set-like object with no particular ordering of elements. The ith element of
    `camera_ids` corresponds to the ith element of `camera_counts`.
    """
    origins: torch.Tensor
    directions: torch.Tensor
    lengths: torch.Tensor
    xys: torch.Tensor
    camera_ids: Optional[torch.Tensor] = None
    camera_counts: Optional[torch.Tensor] = None
def ray_bundle_to_ray_points(
ray_bundle: Union[RayBundle, HeterogeneousRayBundle]
) -> torch.Tensor:
""" """
Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle` Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
named tuple) to 3D points by extending each ray according to the corresponding named tuple or HeterogeneousRayBundle dataclass) to 3D points by
length. extending each ray according to the corresponding length.
E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions` E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
and `ray_bundle.lengths`, the ray point at position `[i, j]` is: and `ray_bundle.lengths`, the ray point at position `[i, j]` is:
@ -43,7 +84,7 @@ def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
`ray_bundle.directions` matter. `ray_bundle.directions` matter.
Args: Args:
ray_bundle: A `RayBundle` object with fields: ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields:
origins: A tensor of shape `(..., 3)` origins: A tensor of shape `(..., 3)`
directions: A tensor of shape `(..., 3)` directions: A tensor of shape `(..., 3)`
lengths: A tensor of shape `(..., num_points_per_ray)` lengths: A tensor of shape `(..., num_points_per_ray)`

View File

@ -152,6 +152,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
min_depth=1.0, min_depth=1.0,
max_depth=10.0, max_depth=10.0,
n_pts_per_ray=10, n_pts_per_ray=10,
n_rays_total=None,
n_rays_per_image=None,
): ):
raysampler_params = { raysampler_params = {
"min_x": min_x, "min_x": min_x,
@ -161,6 +163,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
"n_pts_per_ray": n_pts_per_ray, "n_pts_per_ray": n_pts_per_ray,
"min_depth": min_depth, "min_depth": min_depth,
"max_depth": max_depth, "max_depth": max_depth,
"n_rays_total": n_rays_total,
"n_rays_per_image": n_rays_per_image,
} }
if issubclass(raysampler_type, MultinomialRaysampler): if issubclass(raysampler_type, MultinomialRaysampler):
@ -168,7 +172,11 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
{"image_width": image_width, "image_height": image_height} {"image_width": image_width, "image_height": image_height}
) )
elif issubclass(raysampler_type, MonteCarloRaysampler): elif issubclass(raysampler_type, MonteCarloRaysampler):
raysampler_params["n_rays_per_image"] = image_width * image_height raysampler_params["n_rays_per_image"] = (
image_width * image_height
if (n_rays_total is None) and (n_rays_per_image is None)
else n_rays_per_image
)
else: else:
raise ValueError(str(raysampler_type)) raise ValueError(str(raysampler_type))
@ -580,3 +588,55 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
# samples[3] has enough sources, so must contain 3 distinct values. # samples[3] has enough sources, so must contain 3 distinct values.
self.assertLessEqual(samples[3].max(), 3) self.assertLessEqual(samples[3].max(), 3)
self.assertEqual(len(set(samples[3].tolist())), 3) self.assertEqual(len(set(samples[3].tolist())), 3)
def test_heterogeneous_sampling(self, batch_size=8):
    """
    Test that the output of heterogeneous sampling has the first dimension equal
    to n_rays_total and second to 1 and that ray_bundle elements from different
    sampled cameras are different and equal for same sampled cameras.
    """
    cameras = init_random_cameras(PerspectiveCameras, batch_size, random_z=True)
    for n_rays_total in [2, 3, 17, 21, 32]:
        for cls in (MultinomialRaysampler, MonteCarloRaysampler):
            with self.subTest(cls.__name__ + ", n_rays_total=" + str(n_rays_total)):
                raysampler = self.init_raysampler(
                    cls, n_rays_total=n_rays_total, n_rays_per_image=None
                )
                ray_bundle = raysampler(cameras)

                # test whether they are of the correct shape
                for attr in ("origins", "directions", "lengths", "xys"):
                    tensor = getattr(ray_bundle, attr)
                    self.assertEqual(
                        tensor.shape[:2],
                        torch.Size((n_rays_total, 1)),
                        tensor.shape,
                    )

                # Expand (camera_ids, camera_counts) to one camera id per ray.
                # Computed once, it does not change between pair comparisons.
                per_ray_camera_ids = torch.repeat_interleave(
                    ray_bundle.camera_ids, ray_bundle.camera_counts
                )
                rays = list(
                    zip(
                        ray_bundle.origins,
                        ray_bundle.directions,
                        ray_bundle.lengths,
                        per_ray_camera_ids,
                    )
                )

                # if two camera ids are same than origins should also be the same
                # directions and xys are always different and lengths equal
                for i1, (origin1, dir1, len1, id1) in enumerate(rays):
                    for i2, (origin2, dir2, len2, id2) in enumerate(rays):
                        if i1 == i2:
                            continue
                        same_camera = bool(id1 == id2)
                        self.assertEqual(
                            torch.allclose(origin1, origin2, rtol=1e-4, atol=1e-4),
                            same_camera,
                            (origin1, origin2, id1, id2),
                        )
                        self.assertFalse(torch.allclose(dir1, dir2), (dir1, dir2))
                        self.assertClose(len1, len2)