mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
New raysamplers
Summary: New MultinomialRaysampler succeeds GridRaysampler bringing masking and subsampling. Correspondingly, NDCMultinomialRaysampler succeeds NDCGridRaysampler. Reviewed By: nikhilaravi, shapovalov Differential Revision: D33256897 fbshipit-source-id: cd80ec6f35b110d1d20a75c62f4e889ba8fa5d45
This commit is contained in:
parent
174738c33e
commit
3eb4233844
@ -7,7 +7,7 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points, HarmonicEmbedding
|
||||
from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points
|
||||
|
||||
from .linear_with_repeat import LinearWithRepeat
|
||||
|
||||
|
@ -32,7 +32,9 @@ from .implicit import (
|
||||
HarmonicEmbedding,
|
||||
ImplicitRenderer,
|
||||
MonteCarloRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCGridRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
RayBundle,
|
||||
VolumeRenderer,
|
||||
VolumeSampler,
|
||||
|
@ -6,7 +6,13 @@
|
||||
|
||||
from .harmonic_embedding import HarmonicEmbedding
|
||||
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||
from .raysampling import (
|
||||
GridRaysampler,
|
||||
MonteCarloRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCGridRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
)
|
||||
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
|
||||
from .utils import (
|
||||
RayBundle,
|
||||
@ -14,4 +20,5 @@ from .utils import (
|
||||
ray_bundle_variables_to_ray_points,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
@ -4,22 +4,26 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from ..cameras import CamerasBase
|
||||
from .utils import RayBundle
|
||||
import torch
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.renderer.implicit.utils import RayBundle
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
"""
|
||||
This file defines three raysampling techniques:
|
||||
- GridRaysampler which can be used to sample rays from pixels of an image grid
|
||||
- NDCGridRaysampler which can be used to sample rays from pixels of an image grid,
|
||||
- MultinomialRaysampler which can be used to sample rays from pixels of an image grid
|
||||
- NDCMultinomialRaysampler which can be used to sample rays from pixels of an image grid,
|
||||
which follows the pytorch3d convention for image grid coordinates
|
||||
- MonteCarloRaysampler which randomly selects image pixels and emits rays from them
|
||||
- MonteCarloRaysampler which randomly selects real-valued locations in the image plane
|
||||
and emits rays from them
|
||||
"""
|
||||
|
||||
|
||||
class GridRaysampler(torch.nn.Module):
|
||||
class MultinomialRaysampler(torch.nn.Module):
|
||||
"""
|
||||
Samples a fixed number of points along rays which are regularly distributed
|
||||
in a batch of rectangular image grids. Points along each ray
|
||||
@ -44,19 +48,20 @@ class GridRaysampler(torch.nn.Module):
|
||||
< --- image_width --- >
|
||||
```
|
||||
|
||||
In order to generate ray points, `GridRaysampler` takes each 3D point of
|
||||
In order to generate ray points, `MultinomialRaysampler` takes each 3D point of
|
||||
the grid (with coordinates `[x, y, depth]`) and unprojects it
|
||||
with `cameras.unproject_points([x, y, depth])`, where `cameras` are an
|
||||
additional input to the `forward` function.
|
||||
|
||||
Note that this is a generic implementation that can support any image grid
|
||||
coordinate convention. For a raysampler which follows the PyTorch3D
|
||||
coordinate conventions please refer to `NDCGridRaysampler`.
|
||||
As such, `NDCGridRaysampler` is a special case of `GridRaysampler`.
|
||||
coordinate conventions please refer to `NDCMultinomialRaysampler`.
|
||||
As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
min_x: float,
|
||||
max_x: float,
|
||||
min_y: float,
|
||||
@ -66,6 +71,9 @@ class GridRaysampler(torch.nn.Module):
|
||||
n_pts_per_ray: int,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
n_rays_per_image: Optional[int] = None,
|
||||
unit_directions: bool = False,
|
||||
stratified_sampling: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -78,11 +86,18 @@ class GridRaysampler(torch.nn.Module):
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
min_depth: The minimum depth of a ray-point.
|
||||
max_depth: The maximum depth of a ray-point.
|
||||
n_rays_per_image: If given, this amount of rays are sampled from the grid.
|
||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||
stratified_sampling: if set, performs stratified random sampling
|
||||
along the ray; otherwise takes ray points at deterministic offsets.
|
||||
"""
|
||||
super().__init__()
|
||||
self._n_pts_per_ray = n_pts_per_ray
|
||||
self._min_depth = min_depth
|
||||
self._max_depth = max_depth
|
||||
self._n_rays_per_image = n_rays_per_image
|
||||
self._unit_directions = unit_directions
|
||||
self._stratified_sampling = stratified_sampling
|
||||
|
||||
# get the initial grid of image xy coords
|
||||
_xy_grid = torch.stack(
|
||||
@ -96,69 +111,127 @@ class GridRaysampler(torch.nn.Module):
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
|
||||
|
||||
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
||||
def forward(
    self,
    cameras: CamerasBase,
    *,
    mask: Optional[torch.Tensor] = None,
    min_depth: Optional[float] = None,
    max_depth: Optional[float] = None,
    n_rays_per_image: Optional[int] = None,
    n_pts_per_ray: Optional[int] = None,
    stratified_sampling: Optional[bool] = None,
    **kwargs,
) -> RayBundle:
    """
    Args:
        cameras: A batch of `batch_size` cameras from which the rays are emitted.
        mask: if given, the rays are sampled from the mask. Should be of size
            (batch_size, image_height, image_width).
        min_depth: The minimum depth of a ray-point. Defaults to the value
            passed to the constructor.
        max_depth: The maximum depth of a ray-point. Defaults to the value
            passed to the constructor.
        n_rays_per_image: If given, this amount of rays are sampled from the grid.
            Defaults to the value passed to the constructor.
        n_pts_per_ray: The number of points sampled along each ray.
            Defaults to the value passed to the constructor.
        stratified_sampling: if set, performs stratified sampling in n_pts_per_ray
            bins for each ray; otherwise takes n_pts_per_ray deterministic points
            on each ray with uniform offsets. If `None` (the default), the value
            passed to the constructor is used.

    Returns:
        A named tuple RayBundle with the following fields:
        origins: A tensor of shape
            `(batch_size, s1, s2, 3)`
            denoting the locations of ray origins in the world coordinates.
        directions: A tensor of shape
            `(batch_size, s1, s2, 3)`
            denoting the directions of each ray in the world coordinates.
        lengths: A tensor of shape
            `(batch_size, s1, s2, n_pts_per_ray)`
            containing the z-coordinate (=depth) of each ray in world units.
        xys: A tensor of shape
            `(batch_size, s1, s2, 2)`
            containing the 2D image coordinates of each ray.

        Here `s1, s2` refer to spatial dimensions. When no subsampling takes
        place (neither a mask nor a number of rays is given), they equal
        `(image_height, image_width)`; otherwise they equal `(n, 1)`,
        where `n` is `n_rays_per_image` if provided, otherwise the minimum
        cardinality of the mask in the batch.
    """
    batch_size = cameras.R.shape[0]
    device = cameras.device

    # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
    xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)

    num_rays = n_rays_per_image or self._n_rays_per_image
    if mask is not None and num_rays is None:
        # if num rays not given, sample according to the smallest mask
        num_rays = mask.sum(dim=(1, 2)).min().int().item()

    if num_rays is not None:
        if mask is not None:
            assert mask.shape == xy_grid.shape[:3]
            weights = mask.reshape(batch_size, -1)
        else:
            # torch.randperm could be cheaper for uniform weights, but it is
            # neither batched nor able to produce a partial permutation,
            # so multinomial sampling with constant weights is used instead.
            _, height, width, _ = xy_grid.shape
            weights = xy_grid.new_ones(batch_size, height * width)
        # duplicate each sampled flat index for the x and the y coordinate
        rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2)

        # pick the sampled pixels and restore a (B, n, 1, 2) layout
        xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
            :, :, None
        ]

    # per-call overrides take precedence over the constructor settings
    min_depth = min_depth if min_depth is not None else self._min_depth
    max_depth = max_depth if max_depth is not None else self._max_depth
    n_pts_per_ray = (
        n_pts_per_ray if n_pts_per_ray is not None else self._n_pts_per_ray
    )
    stratified_sampling = (
        stratified_sampling
        if stratified_sampling is not None
        else self._stratified_sampling
    )

    return _xy_to_ray_bundle(
        cameras,
        xy_grid,
        min_depth,
        max_depth,
        n_pts_per_ray,
        self._unit_directions,
        stratified_sampling,
    )
|
||||
|
||||
|
||||
class NDCGridRaysampler(GridRaysampler):
|
||||
class NDCMultinomialRaysampler(MultinomialRaysampler):
|
||||
"""
|
||||
Samples a fixed number of points along rays which are regularly distributed
|
||||
in a batch of rectangular image grids. Points along each ray
|
||||
have uniformly-spaced z-coordinates between a predefined minimum and maximum depth.
|
||||
|
||||
`NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds`
|
||||
`NDCMultinomialRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds`
|
||||
renderers. I.e. the pixel coordinates are in [-1, 1]x[-u, u] or [-u, u]x[-1, 1]
|
||||
where u > 1 is the aspect ratio of the image.
|
||||
|
||||
For the description of arguments, see the documentation to MultinomialRaysampler.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
n_pts_per_ray: int,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
n_rays_per_image: Optional[int] = None,
|
||||
unit_directions: bool = False,
|
||||
stratified_sampling: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
image_width: The horizontal size of the image grid.
|
||||
image_height: The vertical size of the image grid.
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
min_depth: The minimum depth of a ray-point.
|
||||
max_depth: The maximum depth of a ray-point.
|
||||
"""
|
||||
if image_width >= image_height:
|
||||
range_x = image_width / image_height
|
||||
range_y = 1.0
|
||||
@ -178,6 +251,9 @@ class NDCGridRaysampler(GridRaysampler):
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
min_depth=min_depth,
|
||||
max_depth=max_depth,
|
||||
n_rays_per_image=n_rays_per_image,
|
||||
unit_directions=unit_directions,
|
||||
stratified_sampling=stratified_sampling,
|
||||
)
|
||||
|
||||
|
||||
@ -187,6 +263,9 @@ class MonteCarloRaysampler(torch.nn.Module):
|
||||
For each pixel, a fixed number of points is sampled along its ray at uniformly-spaced
|
||||
z-coordinates such that the z-coordinates range between a predefined minimum
|
||||
and maximum depth.
|
||||
|
||||
For practical purposes, this is similar to MultinomialRaysampler without a mask,
|
||||
however sampling at real-valued locations bypassing replacement checks may be faster.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -199,6 +278,9 @@ class MonteCarloRaysampler(torch.nn.Module):
|
||||
n_pts_per_ray: int,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
*,
|
||||
unit_directions: bool = False,
|
||||
stratified_sampling: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -210,6 +292,10 @@ class MonteCarloRaysampler(torch.nn.Module):
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
min_depth: The minimum depth of each ray-point.
|
||||
max_depth: The maximum depth of each ray-point.
|
||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||
stratified_sampling: if set, performs stratified sampling in n_pts_per_ray
|
||||
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
||||
on each ray with uniform offsets.
|
||||
"""
|
||||
super().__init__()
|
||||
self._min_x = min_x
|
||||
@ -220,11 +306,18 @@ class MonteCarloRaysampler(torch.nn.Module):
|
||||
self._n_pts_per_ray = n_pts_per_ray
|
||||
self._min_depth = min_depth
|
||||
self._max_depth = max_depth
|
||||
self._unit_directions = unit_directions
|
||||
self._stratified_sampling = stratified_sampling
|
||||
|
||||
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
||||
def forward(
|
||||
self, cameras: CamerasBase, *, stratified_sampling: bool = False, **kwargs
|
||||
) -> RayBundle:
|
||||
"""
|
||||
Args:
|
||||
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||
stratified_sampling: if set, performs stratified sampling in n_pts_per_ray
|
||||
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
||||
on each ray with uniform offsets.
|
||||
Returns:
|
||||
A named tuple RayBundle with the following fields:
|
||||
origins: A tensor of shape
|
||||
@ -264,10 +357,132 @@ class MonteCarloRaysampler(torch.nn.Module):
|
||||
dim=2,
|
||||
)
|
||||
|
||||
return _xy_to_ray_bundle(
|
||||
cameras, rays_xy, self._min_depth, self._max_depth, self._n_pts_per_ray
|
||||
stratified_sampling = (
|
||||
stratified_sampling
|
||||
if stratified_sampling is not None
|
||||
else self._stratified_sampling
|
||||
)
|
||||
|
||||
return _xy_to_ray_bundle(
|
||||
cameras,
|
||||
rays_xy,
|
||||
self._min_depth,
|
||||
self._max_depth,
|
||||
self._n_pts_per_ray,
|
||||
self._unit_directions,
|
||||
stratified_sampling,
|
||||
)
|
||||
|
||||
|
||||
# Settings for backwards compatibility
|
||||
def GridRaysampler(
    min_x: float,
    max_x: float,
    min_y: float,
    max_y: float,
    image_width: int,
    image_height: int,
    n_pts_per_ray: int,
    min_depth: float,
    max_depth: float,
) -> "MultinomialRaysampler":
    """
    GridRaysampler has been DEPRECATED. Use MultinomialRaysampler instead.
    Preserving GridRaysampler for backward compatibility.

    Emits a `PendingDeprecationWarning` and returns a `MultinomialRaysampler`
    constructed from the given arguments, all of which are forwarded unchanged
    as keyword arguments.
    """

    warnings.warn(
        "GridRaysampler is deprecated,\n"
        "Use MultinomialRaysampler instead.\n"
        "GridRaysampler will be removed in future releases.",
        PendingDeprecationWarning,
        # attribute the warning to the code calling this wrapper
        stacklevel=2,
    )

    return MultinomialRaysampler(
        min_x=min_x,
        max_x=max_x,
        min_y=min_y,
        max_y=max_y,
        image_width=image_width,
        image_height=image_height,
        n_pts_per_ray=n_pts_per_ray,
        min_depth=min_depth,
        max_depth=max_depth,
    )
|
||||
|
||||
|
||||
# Settings for backwards compatibility
|
||||
def NDCGridRaysampler(
    image_width: int,
    image_height: int,
    n_pts_per_ray: int,
    min_depth: float,
    max_depth: float,
) -> "NDCMultinomialRaysampler":
    """
    NDCGridRaysampler has been DEPRECATED. Use NDCMultinomialRaysampler instead.
    Preserving NDCGridRaysampler for backward compatibility.

    Emits a `PendingDeprecationWarning` and returns an `NDCMultinomialRaysampler`
    constructed from the given arguments, all of which are forwarded unchanged
    as keyword arguments.
    """

    warnings.warn(
        "NDCGridRaysampler is deprecated,\n"
        "Use NDCMultinomialRaysampler instead.\n"
        "NDCGridRaysampler will be removed in future releases.",
        PendingDeprecationWarning,
        # attribute the warning to the code calling this wrapper
        stacklevel=2,
    )

    return NDCMultinomialRaysampler(
        image_width=image_width,
        image_height=image_height,
        n_pts_per_ray=n_pts_per_ray,
        min_depth=min_depth,
        max_depth=max_depth,
    )
|
||||
|
||||
|
||||
def _safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper around torch.multinomial that attempts sampling without replacement
|
||||
when possible, otherwise resorts to sampling with replacement.
|
||||
|
||||
Args:
|
||||
input: tensor of shape [B, n] containing non-negative values;
|
||||
rows are interpreted as unnormalized event probabilities
|
||||
in categorical distributions.
|
||||
num_samples: number of samples to take.
|
||||
|
||||
Returns:
|
||||
LongTensor of shape [B, num_samples] containing
|
||||
values from {0, ..., n - 1} where the elements [i, :] of row i make
|
||||
(1) if there are num_samples or more non-zero values in input[i],
|
||||
a random subset of the indices of those values, with
|
||||
probabilities proportional to the values in input[i, :].
|
||||
|
||||
(2) if not, a random sample with replacement of the indices of
|
||||
those values, with probabilities proportional to them.
|
||||
This sample might not contain all the indices of the
|
||||
non-zero values.
|
||||
Behavior undetermined if there are no non-zero values in a whole row
|
||||
or if there are negative values.
|
||||
"""
|
||||
try:
|
||||
res = torch.multinomial(input, num_samples, replacement=False)
|
||||
except RuntimeError:
|
||||
# this is probably rare, so we don't mind sampling twice
|
||||
res = torch.multinomial(input, num_samples, replacement=True)
|
||||
no_repl = (input > 0.0).sum(dim=-1) >= num_samples
|
||||
res[no_repl] = torch.multinomial(input[no_repl], num_samples, replacement=False)
|
||||
return res
|
||||
|
||||
# in some versions of Pytorch, zero probabilty samples can be drawn without an error
|
||||
# due to this bug: https://github.com/pytorch/pytorch/issues/50034. Handle this case:
|
||||
repl = (input > 0.0).sum(dim=-1) < num_samples
|
||||
# pyre-fixme[16]: Undefined attribute `torch.ByteTensor` has no attribute `any`.
|
||||
if repl.any():
|
||||
res[repl] = torch.multinomial(input[repl], num_samples, replacement=True)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _xy_to_ray_bundle(
|
||||
cameras: CamerasBase,
|
||||
@ -275,6 +490,8 @@ def _xy_to_ray_bundle(
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
n_pts_per_ray: int,
|
||||
unit_directions: bool,
|
||||
stratified_sampling: bool = False,
|
||||
) -> RayBundle:
|
||||
"""
|
||||
Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays.
|
||||
@ -283,16 +500,36 @@ def _xy_to_ray_bundle(
|
||||
|
||||
The extended grid is then unprojected with `cameras` to yield
|
||||
ray origins, directions and depths.
|
||||
|
||||
Args:
|
||||
cameras: cameras object representing a batch of cameras.
|
||||
xy_grid: torch.tensor grid of image xy coords.
|
||||
min_depth: The minimum depth of each ray-point.
|
||||
max_depth: The maximum depth of each ray-point.
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||
stratified_sampling: if set, performs stratified sampling in n_pts_per_ray
|
||||
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
||||
on each ray with uniform offsets.
|
||||
"""
|
||||
batch_size = xy_grid.shape[0]
|
||||
spatial_size = xy_grid.shape[1:-1]
|
||||
n_rays_per_image = spatial_size.numel() # pyre-ignore
|
||||
|
||||
# ray z-coords
|
||||
depths = torch.linspace(
|
||||
min_depth, max_depth, n_pts_per_ray, dtype=xy_grid.dtype, device=xy_grid.device
|
||||
)
|
||||
rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray)
|
||||
rays_zs = xy_grid.new_empty((0,))
|
||||
if n_pts_per_ray > 0:
|
||||
depths = torch.linspace(
|
||||
min_depth,
|
||||
max_depth,
|
||||
n_pts_per_ray,
|
||||
dtype=xy_grid.dtype,
|
||||
device=xy_grid.device,
|
||||
)
|
||||
rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray)
|
||||
|
||||
if stratified_sampling:
|
||||
rays_zs = _jiggle_within_stratas(rays_zs)
|
||||
|
||||
# make two sets of points at a constant depth=1 and 2
|
||||
to_unproject = torch.cat(
|
||||
@ -320,6 +557,8 @@ def _xy_to_ray_bundle(
|
||||
|
||||
# directions are the differences between the two planes of points
|
||||
rays_directions_world = rays_plane_2_world - rays_plane_1_world
|
||||
if unit_directions:
|
||||
rays_directions_world = F.normalize(rays_directions_world, dim=-1)
|
||||
|
||||
# origins are given by subtracting the ray directions from the first plane
|
||||
rays_origins_world = rays_plane_1_world - rays_directions_world
|
||||
@ -330,3 +569,31 @@ def _xy_to_ray_bundle(
|
||||
rays_zs.view(batch_size, *spatial_size, n_pts_per_ray),
|
||||
xy_grid,
|
||||
)
|
||||
|
||||
|
||||
def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Performs sampling of 1 point per bin given the bin centers.
|
||||
|
||||
More specifically, it replaces each point's value `z`
|
||||
with a sample from a uniform random distribution on
|
||||
`[z - delta_−, z + delta_+]`, where `delta_−` is half of the difference
|
||||
between `z` and the previous point, and `delta_+` is half of the difference
|
||||
between the next point and `z`. For the first and last items, the
|
||||
corresponding boundary deltas are assumed zero.
|
||||
|
||||
Args:
|
||||
`bin_centers`: The input points of size (..., N); the result is broadcast
|
||||
along all but the last dimension (the rows). Each row should be
|
||||
sorted in ascending order.
|
||||
|
||||
Returns:
|
||||
a tensor of size (..., N) with the locations jiggled within stratas/bins.
|
||||
"""
|
||||
# Get intervals between bin centers.
|
||||
mids = 0.5 * (bin_centers[..., 1:] + bin_centers[..., :-1])
|
||||
upper = torch.cat((mids, bin_centers[..., -1:]), dim=-1)
|
||||
lower = torch.cat((bin_centers[..., :1], mids), dim=-1)
|
||||
# Samples in those intervals.
|
||||
jiggled = lower + (upper - lower) * torch.rand_like(lower)
|
||||
return jiggled
|
||||
|
@ -10,9 +10,9 @@ from fvcore.common.benchmark import benchmark
|
||||
from pytorch3d.renderer import (
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
GridRaysampler,
|
||||
MonteCarloRaysampler,
|
||||
NDCGridRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
)
|
||||
@ -21,7 +21,11 @@ from test_raysampling import TestRaysampling
|
||||
|
||||
def bm_raysampling() -> None:
|
||||
case_grid = {
|
||||
"raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler],
|
||||
"raysampler_type": [
|
||||
MultinomialRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
MonteCarloRaysampler,
|
||||
],
|
||||
"camera_type": [
|
||||
PerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from numbers import Real
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
@ -190,3 +191,13 @@ class TestCaseMixin(unittest.TestCase):
|
||||
if msg is not None:
|
||||
self.fail(f"{msg} {err}")
|
||||
self.fail(err)
|
||||
|
||||
def assertConstant(self, input: TensorOrArray, value: Real) -> None:
    """
    Asserts that every element of input equals value, by checking that both
    the minimum and the maximum of input are equal to value.

    Args:
        input: tensor or array
        value: the value that both extremes of input are compared against
    """
    smallest = input.min()
    self.assertEqual(smallest, value)
    largest = input.max()
    self.assertEqual(largest, value)
|
||||
|
@ -5,17 +5,27 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import eyes
|
||||
from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||
from pytorch3d.renderer import (
|
||||
MonteCarloRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCGridRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import (
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
)
|
||||
from pytorch3d.renderer.implicit.raysampling import (
|
||||
_jiggle_within_stratas,
|
||||
_safe_multinomial,
|
||||
)
|
||||
from pytorch3d.renderer.implicit.utils import (
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
@ -93,14 +103,16 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def raysampler(
|
||||
raysampler_type=GridRaysampler,
|
||||
camera_type=PerspectiveCameras,
|
||||
n_pts_per_ray=10,
|
||||
batch_size=1,
|
||||
image_width=10,
|
||||
image_height=20,
|
||||
):
|
||||
|
||||
raysampler_type,
|
||||
camera_type,
|
||||
n_pts_per_ray: int,
|
||||
batch_size: int,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> Callable[[], None]:
|
||||
"""
|
||||
Used for benchmarks.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# init raysamplers
|
||||
@ -120,7 +132,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
# init a batch of random cameras
|
||||
cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)
|
||||
|
||||
def run_raysampler():
|
||||
def run_raysampler() -> None:
|
||||
raysampler(cameras=cameras)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@ -128,7 +140,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def init_raysampler(
|
||||
raysampler_type=GridRaysampler,
|
||||
raysampler_type,
|
||||
min_x=-1.0,
|
||||
max_x=1.0,
|
||||
min_y=-1.0,
|
||||
@ -149,7 +161,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
"max_depth": max_depth,
|
||||
}
|
||||
|
||||
if issubclass(raysampler_type, GridRaysampler):
|
||||
if issubclass(raysampler_type, MultinomialRaysampler):
|
||||
raysampler_params.update(
|
||||
{"image_width": image_width, "image_height": image_height}
|
||||
)
|
||||
@ -158,7 +170,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
else:
|
||||
raise ValueError(str(raysampler_type))
|
||||
|
||||
if issubclass(raysampler_type, NDCGridRaysampler):
|
||||
if issubclass(raysampler_type, NDCMultinomialRaysampler):
|
||||
# NDCMultinomialRaysampler does not use min/max_x/y
|
||||
for k in ("min_x", "max_x", "min_y", "max_y"):
|
||||
del raysampler_params[k]
|
||||
@ -191,8 +203,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
for raysampler_type in (
|
||||
MonteCarloRaysampler,
|
||||
GridRaysampler,
|
||||
NDCGridRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
):
|
||||
|
||||
raysampler = TestRaysampling.init_raysampler(
|
||||
@ -208,7 +220,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
)
|
||||
|
||||
if issubclass(raysampler_type, NDCGridRaysampler):
|
||||
if issubclass(raysampler_type, NDCMultinomialRaysampler):
|
||||
# adjust the gt bounds for NDCMultinomialRaysampler
|
||||
if image_width >= image_height:
|
||||
range_x = image_width / image_height
|
||||
@ -297,7 +309,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
Checks the shapes of raysampler outputs.
|
||||
"""
|
||||
|
||||
if isinstance(raysampler, GridRaysampler):
|
||||
if isinstance(raysampler, MultinomialRaysampler):
|
||||
spatial_size = [image_height, image_width]
|
||||
elif isinstance(raysampler, MonteCarloRaysampler):
|
||||
spatial_size = [image_height * image_width]
|
||||
@ -386,7 +398,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# check that projected world points' xy coordinates
|
||||
# range correctly between [min_x/y, max_x/y]
|
||||
if isinstance(raysampler, GridRaysampler):
|
||||
if isinstance(raysampler, MultinomialRaysampler):
|
||||
# get the expected coordinates along each grid axis
|
||||
ys, xs = [
|
||||
torch.linspace(
|
||||
@ -518,3 +530,51 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
state = module1.state_dict()
|
||||
module2.load_state_dict(state)
|
||||
|
||||
def test_jiggle(self):
    """
    Checks the shape, ordering, bounds and spread of the output
    of `_jiggle_within_stratas` on random ascending data.
    """
    scale = 180
    # random data which is in ascending order along the last dimension
    increments = torch.rand(8, 3, 4, 20)
    data = scale * torch.cumsum(increments, dim=-1)

    out = _jiggle_within_stratas(data)
    self.assertTupleEqual(out.shape, data.shape)

    # Check `out` is in ascending order
    steps = torch.diff(out, dim=-1)
    self.assertGreater(steps.min(), 0)

    # each jiggled point stays below the next center and above the previous one
    below_next_center = out[..., :-1] < data[..., 1:]
    self.assertConstant(below_next_center, True)
    above_previous_center = data[..., :-1] < out[..., 1:]
    self.assertConstant(above_previous_center, True)

    # jiggles is random between -scale/2 and scale/2
    jiggles = out - data
    lower_bound, upper_bound = -0.5 * scale, 0.5 * scale
    self.assertLess(jiggles.min(), -0.4 * scale)
    self.assertGreater(jiggles.min(), lower_bound)
    self.assertGreater(jiggles.max(), 0.4 * scale)
    self.assertLess(jiggles.max(), upper_bound)
|
||||
|
||||
def test_safe_multinomial(self):
    """
    Checks `_safe_multinomial` on rows holding 1, 2, 3 and 4 non-zero
    weights when 3 samples per row are requested.
    """
    n_categories = 5
    n_samples = 3
    # row k has its first k + 1 entries set to one and the rest to zero
    mask = [[1] * k + [0] * (n_categories - k) for k in range(1, n_categories)]
    tmask = torch.tensor(mask, dtype=torch.float32)

    for _ in range(5):
        # the weights are unnormalized, so a random scaling must not matter
        random_scalar = torch.rand(1)
        samples = _safe_multinomial(tmask * random_scalar, n_samples)
        self.assertTupleEqual(samples.shape, (4, n_samples))

        one_source, two_sources, three_sources, four_sources = samples

        # samples[0] is exactly determined
        self.assertConstant(one_source, 0)

        # samples[1] can only hold the two non-zero indices
        self.assertGreaterEqual(two_sources.min(), 0)
        self.assertLessEqual(two_sources.max(), 1)

        # samples[2] is exactly determined
        self.assertSetEqual(set(three_sources.tolist()), {0, 1, 2})

        # samples[3] has enough sources, so must contain 3 distinct values.
        self.assertLessEqual(four_sources.max(), 3)
        self.assertEqual(len(set(four_sources.tolist())), 3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user