mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Heterogeneous raysampling -> HeterogeneousRayBundle
Summary: Added heterogeneous raysampling to the pytorch3d raysampler, so that different cameras can be sampled a different number of times. It now returns RayBundle if heterogeneous raysampling is off and the new HeterogeneousRayBundle (with added fields `camera_ids` and `camera_counts`) if it is on. Heterogeneous raysampling is on if `n_rays_total` is not None. Reviewed By: bottler Differential Revision: D39542222 fbshipit-source-id: d3d88d822ec7696e856007c088dc36a1cfa8c625
This commit is contained in:
parent
9a0f9ae572
commit
6ae863f301
@ -31,6 +31,7 @@ from .implicit import (
|
|||||||
EmissionAbsorptionRaymarcher,
|
EmissionAbsorptionRaymarcher,
|
||||||
GridRaysampler,
|
GridRaysampler,
|
||||||
HarmonicEmbedding,
|
HarmonicEmbedding,
|
||||||
|
HeterogeneousRayBundle,
|
||||||
ImplicitRenderer,
|
ImplicitRenderer,
|
||||||
MonteCarloRaysampler,
|
MonteCarloRaysampler,
|
||||||
MultinomialRaysampler,
|
MultinomialRaysampler,
|
||||||
|
@ -15,6 +15,7 @@ from .raysampling import (
|
|||||||
)
|
)
|
||||||
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
|
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
HeterogeneousRayBundle,
|
||||||
ray_bundle_to_ray_points,
|
ray_bundle_to_ray_points,
|
||||||
ray_bundle_variables_to_ray_points,
|
ray_bundle_variables_to_ray_points,
|
||||||
RayBundle,
|
RayBundle,
|
||||||
|
@ -5,12 +5,13 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.common.compat import meshgrid_ij
|
from pytorch3d.common.compat import meshgrid_ij
|
||||||
|
from pytorch3d.ops import padded_to_packed
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.renderer.implicit.utils import RayBundle
|
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
@ -73,6 +74,7 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
min_depth: float,
|
min_depth: float,
|
||||||
max_depth: float,
|
max_depth: float,
|
||||||
n_rays_per_image: Optional[int] = None,
|
n_rays_per_image: Optional[int] = None,
|
||||||
|
n_rays_total: Optional[int] = None,
|
||||||
unit_directions: bool = False,
|
unit_directions: bool = False,
|
||||||
stratified_sampling: bool = False,
|
stratified_sampling: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -88,6 +90,11 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
min_depth: The minimum depth of a ray-point.
|
min_depth: The minimum depth of a ray-point.
|
||||||
max_depth: The maximum depth of a ray-point.
|
max_depth: The maximum depth of a ray-point.
|
||||||
n_rays_per_image: If given, this amount of rays are sampled from the grid.
|
n_rays_per_image: If given, this amount of rays are sampled from the grid.
|
||||||
|
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||||
|
is as if `n_rays_total` cameras were sampled with replacement from the
|
||||||
|
cameras provided and for every camera one ray was sampled. If set, this disables
|
||||||
|
`n_rays_per_image` and returns the HeterogeneousRayBundle with
|
||||||
|
batch_size=n_rays_total.
|
||||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||||
stratified_sampling: if True, performs stratified random sampling
|
stratified_sampling: if True, performs stratified random sampling
|
||||||
along the ray; otherwise takes ray points at deterministic offsets.
|
along the ray; otherwise takes ray points at deterministic offsets.
|
||||||
@ -97,6 +104,7 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
self._min_depth = min_depth
|
self._min_depth = min_depth
|
||||||
self._max_depth = max_depth
|
self._max_depth = max_depth
|
||||||
self._n_rays_per_image = n_rays_per_image
|
self._n_rays_per_image = n_rays_per_image
|
||||||
|
self._n_rays_total = n_rays_total
|
||||||
self._unit_directions = unit_directions
|
self._unit_directions = unit_directions
|
||||||
self._stratified_sampling = stratified_sampling
|
self._stratified_sampling = stratified_sampling
|
||||||
|
|
||||||
@ -125,8 +133,9 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
n_rays_per_image: Optional[int] = None,
|
n_rays_per_image: Optional[int] = None,
|
||||||
n_pts_per_ray: Optional[int] = None,
|
n_pts_per_ray: Optional[int] = None,
|
||||||
stratified_sampling: Optional[bool] = None,
|
stratified_sampling: Optional[bool] = None,
|
||||||
|
n_rays_total: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> RayBundle:
|
) -> Union[RayBundle, HeterogeneousRayBundle]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||||
@ -138,8 +147,15 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
n_pts_per_ray: The number of points sampled along each ray.
|
n_pts_per_ray: The number of points sampled along each ray.
|
||||||
stratified_sampling: if set, overrides stratified_sampling provided
|
stratified_sampling: if set, overrides stratified_sampling provided
|
||||||
in __init__.
|
in __init__.
|
||||||
|
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||||
|
is as if `n_rays_total` cameras were sampled with replacement from the
|
||||||
|
cameras provided and for every camera one ray was sampled. If set, this disables
|
||||||
|
`n_rays_per_image` and returns the HeterogeneousRayBundle with
|
||||||
|
batch_size=n_rays_total.
|
||||||
Returns:
|
Returns:
|
||||||
A named tuple RayBundle with the following fields:
|
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
|
||||||
|
following fields:
|
||||||
|
|
||||||
origins: A tensor of shape
|
origins: A tensor of shape
|
||||||
`(batch_size, s1, s2, 3)`
|
`(batch_size, s1, s2, 3)`
|
||||||
denoting the locations of ray origins in the world coordinates.
|
denoting the locations of ray origins in the world coordinates.
|
||||||
@ -153,23 +169,56 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
`(batch_size, s1, s2, 2)`
|
`(batch_size, s1, s2, 2)`
|
||||||
containing the 2D image coordinates of each ray or,
|
containing the 2D image coordinates of each ray or,
|
||||||
if mask is given, `(batch_size, n, 1, 2)`
|
if mask is given, `(batch_size, n, 1, 2)`
|
||||||
Here `s1, s2` refer to spatial dimensions. Unless the mask is
|
Here `s1, s2` refer to spatial dimensions.
|
||||||
given, they equal `(image_height, image_width)`, otherwise `(n, 1)`,
|
`(s1, s2)` refer to (highest priority first):
|
||||||
where `n` is `n_rays_per_image` if provided, otherwise the minimum
|
- `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
|
||||||
cardinality of the mask in the batch.
|
- `(n_rays_per_image, 1)` if `n_rays_per_image` is provided,
|
||||||
|
- `(n, 1)` where n is the minimum cardinality of the mask
|
||||||
|
in the batch if `mask` is provided
|
||||||
|
- `(image_height, image_width)` if nothing from above is satisfied
|
||||||
|
|
||||||
|
`HeterogeneousRayBundle` has additional members:
|
||||||
|
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
|
||||||
|
cameras. It represents unique ids of sampled cameras.
|
||||||
|
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
|
||||||
|
cameras. Represents how many times each camera from `camera_ids` was sampled
|
||||||
|
|
||||||
|
`HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
|
||||||
|
is returned.
|
||||||
"""
|
"""
|
||||||
|
n_rays_total = n_rays_total or self._n_rays_total
|
||||||
|
n_rays_per_image = n_rays_per_image or self._n_rays_per_image
|
||||||
|
assert (n_rays_total is None) or (
|
||||||
|
n_rays_per_image is None
|
||||||
|
), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
|
||||||
|
if n_rays_total:
|
||||||
|
(
|
||||||
|
cameras,
|
||||||
|
mask,
|
||||||
|
camera_ids, # unique ids of sampled cameras
|
||||||
|
camera_counts, # number of times unique camera id was sampled
|
||||||
|
# `n_rays_per_image` is equal to the max number of times a single camera
|
||||||
|
# was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
|
||||||
|
# and then discard the unneeded rays.
|
||||||
|
# pyre-ignore[9]
|
||||||
|
n_rays_per_image,
|
||||||
|
) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
|
||||||
|
else:
|
||||||
|
camera_ids = torch.range(0, len(cameras), dtype=torch.long)
|
||||||
|
|
||||||
batch_size = cameras.R.shape[0]
|
batch_size = cameras.R.shape[0]
|
||||||
device = cameras.device
|
device = cameras.device
|
||||||
|
|
||||||
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
|
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
|
||||||
xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)
|
xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)
|
||||||
|
|
||||||
num_rays = n_rays_per_image or self._n_rays_per_image
|
if mask is not None and n_rays_per_image is None:
|
||||||
if mask is not None and num_rays is None:
|
|
||||||
# if num rays not given, sample according to the smallest mask
|
# if num rays not given, sample according to the smallest mask
|
||||||
num_rays = num_rays or mask.sum(dim=(1, 2)).min().int().item()
|
n_rays_per_image = (
|
||||||
|
n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
|
||||||
|
)
|
||||||
|
|
||||||
if num_rays is not None:
|
if n_rays_per_image is not None:
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
assert mask.shape == xy_grid.shape[:3]
|
assert mask.shape == xy_grid.shape[:3]
|
||||||
weights = mask.reshape(batch_size, -1)
|
weights = mask.reshape(batch_size, -1)
|
||||||
@ -181,7 +230,9 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
weights = xy_grid.new_ones(batch_size, width * height)
|
weights = xy_grid.new_ones(batch_size, width * height)
|
||||||
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
|
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
|
||||||
# float, int]`.
|
# float, int]`.
|
||||||
rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2)
|
rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand(
|
||||||
|
-1, -1, 2
|
||||||
|
)
|
||||||
|
|
||||||
xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
|
xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
|
||||||
:, :, None
|
:, :, None
|
||||||
@ -198,7 +249,7 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
else self._stratified_sampling
|
else self._stratified_sampling
|
||||||
)
|
)
|
||||||
|
|
||||||
return _xy_to_ray_bundle(
|
ray_bundle = _xy_to_ray_bundle(
|
||||||
cameras,
|
cameras,
|
||||||
xy_grid,
|
xy_grid,
|
||||||
min_depth,
|
min_depth,
|
||||||
@ -208,6 +259,13 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
stratified_sampling,
|
stratified_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
# pyre-ignore[61]
|
||||||
|
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
|
||||||
|
if n_rays_total
|
||||||
|
else ray_bundle
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NDCMultinomialRaysampler(MultinomialRaysampler):
|
class NDCMultinomialRaysampler(MultinomialRaysampler):
|
||||||
"""
|
"""
|
||||||
@ -231,6 +289,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler):
|
|||||||
min_depth: float,
|
min_depth: float,
|
||||||
max_depth: float,
|
max_depth: float,
|
||||||
n_rays_per_image: Optional[int] = None,
|
n_rays_per_image: Optional[int] = None,
|
||||||
|
n_rays_total: Optional[int] = None,
|
||||||
unit_directions: bool = False,
|
unit_directions: bool = False,
|
||||||
stratified_sampling: bool = False,
|
stratified_sampling: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -254,6 +313,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler):
|
|||||||
min_depth=min_depth,
|
min_depth=min_depth,
|
||||||
max_depth=max_depth,
|
max_depth=max_depth,
|
||||||
n_rays_per_image=n_rays_per_image,
|
n_rays_per_image=n_rays_per_image,
|
||||||
|
n_rays_total=n_rays_total,
|
||||||
unit_directions=unit_directions,
|
unit_directions=unit_directions,
|
||||||
stratified_sampling=stratified_sampling,
|
stratified_sampling=stratified_sampling,
|
||||||
)
|
)
|
||||||
@ -281,6 +341,7 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
min_depth: float,
|
min_depth: float,
|
||||||
max_depth: float,
|
max_depth: float,
|
||||||
*,
|
*,
|
||||||
|
n_rays_total: Optional[int] = None,
|
||||||
unit_directions: bool = False,
|
unit_directions: bool = False,
|
||||||
stratified_sampling: bool = False,
|
stratified_sampling: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -294,6 +355,11 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
n_pts_per_ray: The number of points sampled along each ray.
|
n_pts_per_ray: The number of points sampled along each ray.
|
||||||
min_depth: The minimum depth of each ray-point.
|
min_depth: The minimum depth of each ray-point.
|
||||||
max_depth: The maximum depth of each ray-point.
|
max_depth: The maximum depth of each ray-point.
|
||||||
|
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||||
|
is as if `n_rays_total` cameras were sampled with replacement from the
|
||||||
|
cameras provided and for every camera one ray was sampled. If set, this disables
|
||||||
|
`n_rays_per_image` and returns the HeterogeneousRayBundle with
|
||||||
|
batch_size=n_rays_total.
|
||||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||||
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
|
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
|
||||||
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
||||||
@ -308,6 +374,7 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
self._n_pts_per_ray = n_pts_per_ray
|
self._n_pts_per_ray = n_pts_per_ray
|
||||||
self._min_depth = min_depth
|
self._min_depth = min_depth
|
||||||
self._max_depth = max_depth
|
self._max_depth = max_depth
|
||||||
|
self._n_rays_total = n_rays_total
|
||||||
self._unit_directions = unit_directions
|
self._unit_directions = unit_directions
|
||||||
self._stratified_sampling = stratified_sampling
|
self._stratified_sampling = stratified_sampling
|
||||||
|
|
||||||
@ -317,15 +384,16 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
*,
|
*,
|
||||||
stratified_sampling: Optional[bool] = None,
|
stratified_sampling: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> RayBundle:
|
) -> Union[RayBundle, HeterogeneousRayBundle]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||||
stratified_sampling: if set, overrides stratified_sampling provided
|
stratified_sampling: if set, overrides stratified_sampling provided
|
||||||
in __init__.
|
in __init__.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A named tuple RayBundle with the following fields:
|
A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
|
||||||
|
following fields:
|
||||||
|
|
||||||
origins: A tensor of shape
|
origins: A tensor of shape
|
||||||
`(batch_size, n_rays_per_image, 3)`
|
`(batch_size, n_rays_per_image, 3)`
|
||||||
denoting the locations of ray origins in the world coordinates.
|
denoting the locations of ray origins in the world coordinates.
|
||||||
@ -338,7 +406,31 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
xys: A tensor of shape
|
xys: A tensor of shape
|
||||||
`(batch_size, n_rays_per_image, 2)`
|
`(batch_size, n_rays_per_image, 2)`
|
||||||
containing the 2D image coordinates of each ray.
|
containing the 2D image coordinates of each ray.
|
||||||
|
If `n_rays_total` is provided, `batch_size=n_rays_total` and
|
||||||
|
`n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
|
||||||
|
is returned.
|
||||||
|
|
||||||
|
`HeterogeneousRayBundle` has additional members:
|
||||||
|
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
|
||||||
|
cameras. It represents unique ids of sampled cameras.
|
||||||
|
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
|
||||||
|
cameras. Represents how many times each camera from `camera_ids` was sampled
|
||||||
"""
|
"""
|
||||||
|
assert (self._n_rays_total is None) or (
|
||||||
|
self._n_rays_per_image is None
|
||||||
|
), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."
|
||||||
|
|
||||||
|
if self._n_rays_total:
|
||||||
|
(
|
||||||
|
cameras,
|
||||||
|
_,
|
||||||
|
camera_ids,
|
||||||
|
camera_counts,
|
||||||
|
n_rays_per_image,
|
||||||
|
) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
|
||||||
|
else:
|
||||||
|
camera_ids = torch.range(0, len(cameras), dtype=torch.long)
|
||||||
|
n_rays_per_image = self._n_rays_per_image
|
||||||
|
|
||||||
batch_size = cameras.R.shape[0]
|
batch_size = cameras.R.shape[0]
|
||||||
|
|
||||||
@ -349,7 +441,7 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
rays_xy = torch.cat(
|
rays_xy = torch.cat(
|
||||||
[
|
[
|
||||||
torch.rand(
|
torch.rand(
|
||||||
size=(batch_size, self._n_rays_per_image, 1),
|
size=(batch_size, n_rays_per_image, 1),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
@ -369,7 +461,7 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
else self._stratified_sampling
|
else self._stratified_sampling
|
||||||
)
|
)
|
||||||
|
|
||||||
return _xy_to_ray_bundle(
|
ray_bundle = _xy_to_ray_bundle(
|
||||||
cameras,
|
cameras,
|
||||||
rays_xy,
|
rays_xy,
|
||||||
self._min_depth,
|
self._min_depth,
|
||||||
@ -379,6 +471,13 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
stratified_sampling,
|
stratified_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
# pyre-ignore[61]
|
||||||
|
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
|
||||||
|
if self._n_rays_total
|
||||||
|
else ray_bundle
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Settings for backwards compatibility
|
# Settings for backwards compatibility
|
||||||
def GridRaysampler(
|
def GridRaysampler(
|
||||||
@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
|
|||||||
# Samples in those intervals.
|
# Samples in those intervals.
|
||||||
jiggled = lower + (upper - lower) * torch.rand_like(lower)
|
jiggled = lower + (upper - lower) * torch.rand_like(lower)
|
||||||
return jiggled
|
return jiggled
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_cameras_and_masks(
|
||||||
|
n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
|
||||||
|
) -> Tuple[
|
||||||
|
CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Samples `n_samples` cameras and masks and returns them in a form
|
||||||
|
(camera_idx, count), where count represents number of times the same camera
|
||||||
|
has been sampled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_samples: how many camera and mask pairs to sample
|
||||||
|
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||||
|
mask: Optional. Should be of size (batch_size, image_height, image_width).
|
||||||
|
Returns:
|
||||||
|
tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids,
|
||||||
|
number_of_times_each_sampled_camera_has_been_sampled,
|
||||||
|
max_number_of_times_camera_has_been_sampled,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
sampled_ids = torch.randint(
|
||||||
|
0,
|
||||||
|
len(cameras),
|
||||||
|
size=(n_samples,),
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
unique_ids, counts = torch.unique(sampled_ids, return_counts=True)
|
||||||
|
return (
|
||||||
|
cameras[unique_ids],
|
||||||
|
mask[unique_ids] if mask is not None else None,
|
||||||
|
unique_ids,
|
||||||
|
counts,
|
||||||
|
torch.max(counts),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _pack_ray_bundle(
|
||||||
|
ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
|
||||||
|
) -> HeterogeneousRayBundle:
|
||||||
|
"""
|
||||||
|
Pack the raybundle from [n_cameras, max(rays_per_camera), ...] to
|
||||||
|
[total_num_rays, 1, ...]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ray_bundle: A ray_bundle to pack
|
||||||
|
camera_ids: Unique ids of cameras that were sampled
|
||||||
|
camera_counts: how many of which camera to pack, each count corresponds to
|
||||||
|
one 'row' of the ray_bundle and says how many rays will be taken
|
||||||
|
from it and packed.
|
||||||
|
Returns:
|
||||||
|
HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1
|
||||||
|
"""
|
||||||
|
camera_counts = camera_counts.to(ray_bundle.origins.device)
|
||||||
|
cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
|
||||||
|
first_idxs = torch.cat(
|
||||||
|
(camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
|
||||||
|
)
|
||||||
|
num_inputs = int(camera_counts.sum())
|
||||||
|
|
||||||
|
return HeterogeneousRayBundle(
|
||||||
|
origins=padded_to_packed(ray_bundle.origins, first_idxs, num_inputs)[:, None],
|
||||||
|
directions=padded_to_packed(ray_bundle.directions, first_idxs, num_inputs)[
|
||||||
|
:, None
|
||||||
|
],
|
||||||
|
lengths=padded_to_packed(ray_bundle.lengths, first_idxs, num_inputs)[:, None],
|
||||||
|
xys=padded_to_packed(ray_bundle.xys, first_idxs, num_inputs)[:, None],
|
||||||
|
camera_ids=camera_ids,
|
||||||
|
camera_counts=camera_counts,
|
||||||
|
)
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ from ...ops.utils import eyes
|
|||||||
from ...structures import Volumes
|
from ...structures import Volumes
|
||||||
from ...transforms import Transform3d
|
from ...transforms import Transform3d
|
||||||
from ..cameras import CamerasBase
|
from ..cameras import CamerasBase
|
||||||
from .raysampling import RayBundle
|
from .raysampling import HeterogeneousRayBundle, RayBundle
|
||||||
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
|
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
|
||||||
|
|
||||||
|
|
||||||
@ -44,12 +44,14 @@ class ImplicitRenderer(torch.nn.Module):
|
|||||||
A standard `volumetric_function` has the following signature:
|
A standard `volumetric_function` has the following signature:
|
||||||
```
|
```
|
||||||
def volumetric_function(
|
def volumetric_function(
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]
|
) -> Tuple[torch.Tensor, torch.Tensor]
|
||||||
```
|
```
|
||||||
With the following arguments:
|
With the following arguments:
|
||||||
`ray_bundle`: A RayBundle object containing the following variables:
|
`ray_bundle`: A RayBundle or HeterogeneousRayBundle object
|
||||||
|
containing the following variables:
|
||||||
|
|
||||||
`origins`: A tensor of shape `(minibatch, ..., 3)` denoting
|
`origins`: A tensor of shape `(minibatch, ..., 3)` denoting
|
||||||
the origins of the rendering rays.
|
the origins of the rendering rays.
|
||||||
`directions`: A tensor of shape `(minibatch, ..., 3)`
|
`directions`: A tensor of shape `(minibatch, ..., 3)`
|
||||||
@ -80,7 +82,7 @@ class ImplicitRenderer(torch.nn.Module):
|
|||||||
RGB sphere with a unit diameter is defined as follows:
|
RGB sphere with a unit diameter is defined as follows:
|
||||||
```
|
```
|
||||||
def volumetric_function(
|
def volumetric_function(
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
@ -109,7 +111,8 @@ class ImplicitRenderer(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
raysampler: A `Callable` that takes as input scene cameras
|
raysampler: A `Callable` that takes as input scene cameras
|
||||||
(an instance of `CamerasBase`) and returns a `RayBundle` that
|
(an instance of `CamerasBase`) and returns a
|
||||||
|
RayBundle or HeterogeneousRayBundle, that
|
||||||
describes the rays emitted from the cameras.
|
describes the rays emitted from the cameras.
|
||||||
raymarcher: A `Callable` that receives the response of the
|
raymarcher: A `Callable` that receives the response of the
|
||||||
`volumetric_function` (an input to `self.forward`) evaluated
|
`volumetric_function` (an input to `self.forward`) evaluated
|
||||||
@ -128,7 +131,7 @@ class ImplicitRenderer(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
|
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
|
||||||
) -> Tuple[torch.Tensor, RayBundle]:
|
) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]:
|
||||||
"""
|
"""
|
||||||
Render a batch of images using a volumetric function
|
Render a batch of images using a volumetric function
|
||||||
represented as a callable (e.g. a Pytorch module).
|
represented as a callable (e.g. a Pytorch module).
|
||||||
@ -145,15 +148,15 @@ class ImplicitRenderer(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
|
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
|
||||||
containing the result of the rendering.
|
containing the result of the rendering.
|
||||||
ray_bundle: A `RayBundle` containing the parametrizations of the
|
ray_bundle: A `Union[RayBundle, HeterogeneousRayBundle]` containing
|
||||||
sampled rendering rays.
|
the parametrizations of the sampled rendering rays.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not callable(volumetric_function):
|
if not callable(volumetric_function):
|
||||||
raise ValueError('"volumetric_function" has to be a "Callable" object.')
|
raise ValueError('"volumetric_function" has to be a "Callable" object.')
|
||||||
|
|
||||||
# first call the ray sampler that returns the RayBundle parametrizing
|
# first call the ray sampler that returns the RayBundle or HeterogeneousRayBundle
|
||||||
# the rendering rays.
|
# parametrizing the rendering rays.
|
||||||
ray_bundle = self.raysampler(
|
ray_bundle = self.raysampler(
|
||||||
cameras=cameras, volumetric_function=volumetric_function, **kwargs
|
cameras=cameras, volumetric_function=volumetric_function, **kwargs
|
||||||
)
|
)
|
||||||
@ -211,7 +214,8 @@ class VolumeRenderer(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
raysampler: A `Callable` that takes as input scene cameras
|
raysampler: A `Callable` that takes as input scene cameras
|
||||||
(an instance of `CamerasBase`) and returns a `RayBundle` that
|
(an instance of `CamerasBase`) and returns a
|
||||||
|
`Union[RayBundle, HeterogeneousRayBundle]` that
|
||||||
describes the rays emitted from the cameras.
|
describes the rays emitted from the cameras.
|
||||||
raymarcher: A `Callable` that receives the `volumes`
|
raymarcher: A `Callable` that receives the `volumes`
|
||||||
(an instance of `Volumes` input to `self.forward`)
|
(an instance of `Volumes` input to `self.forward`)
|
||||||
@ -227,7 +231,7 @@ class VolumeRenderer(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, cameras: CamerasBase, volumes: Volumes, **kwargs
|
self, cameras: CamerasBase, volumes: Volumes, **kwargs
|
||||||
) -> Tuple[torch.Tensor, RayBundle]:
|
) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]:
|
||||||
"""
|
"""
|
||||||
Render a batch of images using raymarching over rays cast through
|
Render a batch of images using raymarching over rays cast through
|
||||||
input `Volumes`.
|
input `Volumes`.
|
||||||
@ -242,8 +246,8 @@ class VolumeRenderer(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
|
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
|
||||||
containing the result of the rendering.
|
containing the result of the rendering.
|
||||||
ray_bundle: A `RayBundle` containing the parametrizations of the
|
ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` containing the
|
||||||
sampled rendering rays.
|
parametrizations of the sampled rendering rays.
|
||||||
"""
|
"""
|
||||||
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
|
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
|
||||||
return self.renderer(
|
return self.renderer(
|
||||||
@ -288,14 +292,14 @@ class VolumeSampler(torch.nn.Module):
|
|||||||
return directions_transform
|
return directions_transform
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, ray_bundle: RayBundle, **kwargs
|
self, ray_bundle: Union[RayBundle, HeterogeneousRayBundle], **kwargs
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Given an input ray parametrization, the forward function samples
|
Given an input ray parametrization, the forward function samples
|
||||||
`self._volumes` at the respective 3D ray-points.
|
`self._volumes` at the respective 3D ray-points.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle object with the following fields:
|
ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields:
|
||||||
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
|
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||||
origins of the sampling rays in world coords.
|
origins of the sampling rays in world coords.
|
||||||
rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
|
rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
|
||||||
|
@ -4,19 +4,25 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import NamedTuple
|
import dataclasses
|
||||||
|
from typing import NamedTuple, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class RayBundle(NamedTuple):
|
class RayBundle(NamedTuple):
|
||||||
"""
|
"""
|
||||||
RayBundle parametrizes points along projection rays by storing ray `origins`,
|
Parametrizes points along projection rays by storing:
|
||||||
`directions` vectors and `lengths` at which the ray-points are sampled.
|
|
||||||
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
|
origins: A tensor of shape `(..., 3)` denoting the
|
||||||
Note that `directions` don't have to be normalized; they define unit vectors
|
origins of the sampling rays in world coords.
|
||||||
in the respective 1D coordinate systems; see documentation for
|
directions: A tensor of shape `(..., 3)` containing the direction
|
||||||
:func:`ray_bundle_to_ray_points` for the conversion formula.
|
vectors of sampling rays in world coords. They don't have to be normalized;
|
||||||
|
they define unit vectors in the respective 1D coordinate systems; see
|
||||||
|
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
|
||||||
|
lengths: A tensor of shape `(..., num_points_per_ray)`
|
||||||
|
containing the lengths at which the rays are sampled.
|
||||||
|
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
|
||||||
"""
|
"""
|
||||||
|
|
||||||
origins: torch.Tensor
|
origins: torch.Tensor
|
||||||
@ -25,11 +31,46 @@ class RayBundle(NamedTuple):
|
|||||||
xys: torch.Tensor
|
xys: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
|
@dataclasses.dataclass
|
||||||
|
class HeterogeneousRayBundle:
|
||||||
|
"""
|
||||||
|
Members:
|
||||||
|
origins: A tensor of shape `(..., 3)` denoting the
|
||||||
|
origins of the sampling rays in world coords.
|
||||||
|
directions: A tensor of shape `(..., 3)` containing the direction
|
||||||
|
vectors of sampling rays in world coords. They don't have to be normalized;
|
||||||
|
they define unit vectors in the respective 1D coordinate systems; see
|
||||||
|
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
|
||||||
|
lengths: A tensor of shape `(..., num_points_per_ray)`
|
||||||
|
containing the lengths at which the rays are sampled.
|
||||||
|
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
|
||||||
|
camera_ids: A tensor of shape (N, ) which indicates which camera
|
||||||
|
was used to sample the rays. `N` is the number of unique sampled cameras.
|
||||||
|
camera_counts: A tensor of shape (N, ) which indicates how many times the
|
||||||
|
corresponding camera in `camera_ids` was sampled.
|
||||||
|
`sum(camera_counts)==total_number_of_rays`
|
||||||
|
|
||||||
|
If we sample cameras of ids [0, 3, 5, 3, 1, 0, 0] that would be
|
||||||
|
stored as camera_ids=[1, 3, 5, 0] and camera_counts=[1, 2, 1, 3]. `camera_ids` is a
|
||||||
|
set-like object with no particular ordering of elements. The ith element of
|
||||||
|
`camera_ids` corresponds to the ith element of `camera_counts`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
origins: torch.Tensor
|
||||||
|
directions: torch.Tensor
|
||||||
|
lengths: torch.Tensor
|
||||||
|
xys: torch.Tensor
|
||||||
|
camera_ids: Optional[torch.Tensor] = None
|
||||||
|
camera_counts: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
def ray_bundle_to_ray_points(
|
||||||
|
ray_bundle: Union[RayBundle, HeterogeneousRayBundle]
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
|
Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
|
||||||
named tuple) to 3D points by extending each ray according to the corresponding
|
named tuple or HeterogeneousRayBundle dataclass) to 3D points by
|
||||||
length.
|
extending each ray according to the corresponding length.
|
||||||
|
|
||||||
E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
|
E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
|
||||||
and `ray_bundle.lengths`, the ray point at position `[i, j]` is:
|
and `ray_bundle.lengths`, the ray point at position `[i, j]` is:
|
||||||
@ -43,7 +84,7 @@ def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
|
|||||||
`ray_bundle.directions` matter.
|
`ray_bundle.directions` matter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A `RayBundle` object with fields:
|
ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields:
|
||||||
origins: A tensor of shape `(..., 3)`
|
origins: A tensor of shape `(..., 3)`
|
||||||
directions: A tensor of shape `(..., 3)`
|
directions: A tensor of shape `(..., 3)`
|
||||||
lengths: A tensor of shape `(..., num_points_per_ray)`
|
lengths: A tensor of shape `(..., num_points_per_ray)`
|
||||||
|
@ -152,6 +152,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
|||||||
min_depth=1.0,
|
min_depth=1.0,
|
||||||
max_depth=10.0,
|
max_depth=10.0,
|
||||||
n_pts_per_ray=10,
|
n_pts_per_ray=10,
|
||||||
|
n_rays_total=None,
|
||||||
|
n_rays_per_image=None,
|
||||||
):
|
):
|
||||||
raysampler_params = {
|
raysampler_params = {
|
||||||
"min_x": min_x,
|
"min_x": min_x,
|
||||||
@ -161,6 +163,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
|||||||
"n_pts_per_ray": n_pts_per_ray,
|
"n_pts_per_ray": n_pts_per_ray,
|
||||||
"min_depth": min_depth,
|
"min_depth": min_depth,
|
||||||
"max_depth": max_depth,
|
"max_depth": max_depth,
|
||||||
|
"n_rays_total": n_rays_total,
|
||||||
|
"n_rays_per_image": n_rays_per_image,
|
||||||
}
|
}
|
||||||
|
|
||||||
if issubclass(raysampler_type, MultinomialRaysampler):
|
if issubclass(raysampler_type, MultinomialRaysampler):
|
||||||
@ -168,7 +172,11 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
|||||||
{"image_width": image_width, "image_height": image_height}
|
{"image_width": image_width, "image_height": image_height}
|
||||||
)
|
)
|
||||||
elif issubclass(raysampler_type, MonteCarloRaysampler):
|
elif issubclass(raysampler_type, MonteCarloRaysampler):
|
||||||
raysampler_params["n_rays_per_image"] = image_width * image_height
|
raysampler_params["n_rays_per_image"] = (
|
||||||
|
image_width * image_height
|
||||||
|
if (n_rays_total is None) and (n_rays_per_image is None)
|
||||||
|
else n_rays_per_image
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(str(raysampler_type))
|
raise ValueError(str(raysampler_type))
|
||||||
|
|
||||||
@ -580,3 +588,55 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
|||||||
# samples[3] has enough sources, so must contain 3 distinct values.
|
# samples[3] has enough sources, so must contain 3 distinct values.
|
||||||
self.assertLessEqual(samples[3].max(), 3)
|
self.assertLessEqual(samples[3].max(), 3)
|
||||||
self.assertEqual(len(set(samples[3].tolist())), 3)
|
self.assertEqual(len(set(samples[3].tolist())), 3)
|
||||||
|
|
||||||
|
def test_heterogeneous_sampling(self, batch_size=8):
|
||||||
|
"""
|
||||||
|
Test that the output of heterogeneous sampling has the first dimension equal
|
||||||
|
to n_rays_total and second to 1 and that ray_bundle elements from different
|
||||||
|
sampled cameras are different and equal for same sampled cameras.
|
||||||
|
"""
|
||||||
|
cameras = init_random_cameras(PerspectiveCameras, batch_size, random_z=True)
|
||||||
|
for n_rays_total in [2, 3, 17, 21, 32]:
|
||||||
|
for cls in (MultinomialRaysampler, MonteCarloRaysampler):
|
||||||
|
with self.subTest(cls.__name__ + ", n_rays_total=" + str(n_rays_total)):
|
||||||
|
raysampler = self.init_raysampler(
|
||||||
|
cls, n_rays_total=n_rays_total, n_rays_per_image=None
|
||||||
|
)
|
||||||
|
ray_bundle = raysampler(cameras)
|
||||||
|
|
||||||
|
# test whether they are of the correct shape
|
||||||
|
for attr in ("origins", "directions", "lengths", "xys"):
|
||||||
|
tensor = getattr(ray_bundle, attr)
|
||||||
|
assert tensor.shape[:2] == torch.Size(
|
||||||
|
(n_rays_total, 1)
|
||||||
|
), tensor.shape
|
||||||
|
|
||||||
|
# if two camera ids are the same then origins should also be the same
|
||||||
|
# directions and xys are always different and lengths equal
|
||||||
|
for i1, (origin1, dir1, len1, id1) in enumerate(
|
||||||
|
zip(
|
||||||
|
ray_bundle.origins,
|
||||||
|
ray_bundle.directions,
|
||||||
|
ray_bundle.lengths,
|
||||||
|
torch.repeat_interleave(
|
||||||
|
ray_bundle.camera_ids, ray_bundle.camera_counts
|
||||||
|
),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
for i2, (origin2, dir2, len2, id2) in enumerate(
|
||||||
|
zip(
|
||||||
|
ray_bundle.origins,
|
||||||
|
ray_bundle.directions,
|
||||||
|
ray_bundle.lengths,
|
||||||
|
torch.repeat_interleave(
|
||||||
|
ray_bundle.camera_ids, ray_bundle.camera_counts
|
||||||
|
),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if i1 == i2:
|
||||||
|
continue
|
||||||
|
assert torch.allclose(
|
||||||
|
origin1, origin2, rtol=1e-4, atol=1e-4
|
||||||
|
) == (id1 == id2), (origin1, origin2, id1, id2)
|
||||||
|
assert not torch.allclose(dir1, dir2), (dir1, dir2)
|
||||||
|
self.assertClose(len1, len2), (len1, len2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user