mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Heterogeneous raysampling -> RayBundleHeterogeneous
Summary: Added heterogeneous raysampling to pytorch3d raysampler, different cameras are sampled different number of times. It now returns RayBundle if heterogeneous raysampling is off and new RayBundleHeterogeneous (with added fields `camera_ids` and `camera_counts`). Heterogeneous raysampling is on if `n_rays_total` is not None. Reviewed By: bottler Differential Revision: D39542222 fbshipit-source-id: d3d88d822ec7696e856007c088dc36a1cfa8c625
This commit is contained in:
		
							parent
							
								
									9a0f9ae572
								
							
						
					
					
						commit
						6ae863f301
					
				@ -31,6 +31,7 @@ from .implicit import (
 | 
			
		||||
    EmissionAbsorptionRaymarcher,
 | 
			
		||||
    GridRaysampler,
 | 
			
		||||
    HarmonicEmbedding,
 | 
			
		||||
    HeterogeneousRayBundle,
 | 
			
		||||
    ImplicitRenderer,
 | 
			
		||||
    MonteCarloRaysampler,
 | 
			
		||||
    MultinomialRaysampler,
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@ from .raysampling import (
 | 
			
		||||
)
 | 
			
		||||
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
 | 
			
		||||
from .utils import (
 | 
			
		||||
    HeterogeneousRayBundle,
 | 
			
		||||
    ray_bundle_to_ray_points,
 | 
			
		||||
    ray_bundle_variables_to_ray_points,
 | 
			
		||||
    RayBundle,
 | 
			
		||||
 | 
			
		||||
@ -5,12 +5,13 @@
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from typing import Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.common.compat import meshgrid_ij
 | 
			
		||||
from pytorch3d.ops import padded_to_packed
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
from pytorch3d.renderer.implicit.utils import RayBundle
 | 
			
		||||
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle
 | 
			
		||||
from torch.nn import functional as F
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -73,6 +74,7 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
        min_depth: float,
 | 
			
		||||
        max_depth: float,
 | 
			
		||||
        n_rays_per_image: Optional[int] = None,
 | 
			
		||||
        n_rays_total: Optional[int] = None,
 | 
			
		||||
        unit_directions: bool = False,
 | 
			
		||||
        stratified_sampling: bool = False,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
@ -88,6 +90,11 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
            min_depth: The minimum depth of a ray-point.
 | 
			
		||||
            max_depth: The maximum depth of a ray-point.
 | 
			
		||||
            n_rays_per_image: If given, this amount of rays are sampled from the grid.
 | 
			
		||||
            n_rays_total: How many rays in total to sample from the cameras provided. The result
 | 
			
		||||
                is as if `n_rays_total` cameras were sampled with replacement from the
 | 
			
		||||
                cameras provided and for every camera one ray was sampled. If set, this disables
 | 
			
		||||
                `n_rays_per_image` and returns the HeterogeneousRayBundle with
 | 
			
		||||
                batch_size=n_rays_total.
 | 
			
		||||
            unit_directions: whether to normalize direction vectors in ray bundle.
 | 
			
		||||
            stratified_sampling: if True, performs stratified random sampling
 | 
			
		||||
                along the ray; otherwise takes ray points at deterministic offsets.
 | 
			
		||||
@ -97,6 +104,7 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
        self._min_depth = min_depth
 | 
			
		||||
        self._max_depth = max_depth
 | 
			
		||||
        self._n_rays_per_image = n_rays_per_image
 | 
			
		||||
        self._n_rays_total = n_rays_total
 | 
			
		||||
        self._unit_directions = unit_directions
 | 
			
		||||
        self._stratified_sampling = stratified_sampling
 | 
			
		||||
 | 
			
		||||
@ -125,8 +133,9 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
        n_rays_per_image: Optional[int] = None,
 | 
			
		||||
        n_pts_per_ray: Optional[int] = None,
 | 
			
		||||
        stratified_sampling: Optional[bool] = None,
 | 
			
		||||
        n_rays_total: Optional[int] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> RayBundle:
 | 
			
		||||
    ) -> Union[RayBundle, HeterogeneousRayBundle]:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
 | 
			
		||||
@ -138,8 +147,15 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
            n_pts_per_ray: The number of points sampled along each ray.
 | 
			
		||||
            stratified_sampling: if set, overrides stratified_sampling provided
 | 
			
		||||
                in __init__.
 | 
			
		||||
            n_rays_total: How many rays in total to sample from the cameras provided. The result
 | 
			
		||||
                is as if `n_rays_total_training` cameras were sampled with replacement from the
 | 
			
		||||
                cameras provided and for every camera one ray was sampled. If set, this disables
 | 
			
		||||
                `n_rays_per_image` and returns the HeterogeneousRayBundle with
 | 
			
		||||
                batch_size=n_rays_total.
 | 
			
		||||
        Returns:
 | 
			
		||||
            A named tuple RayBundle with the following fields:
 | 
			
		||||
            A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
 | 
			
		||||
            following fields:
 | 
			
		||||
 | 
			
		||||
            origins: A tensor of shape
 | 
			
		||||
                `(batch_size, s1, s2, 3)`
 | 
			
		||||
                denoting the locations of ray origins in the world coordinates.
 | 
			
		||||
@ -153,23 +169,56 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
                `(batch_size, s1, s2, 2)`
 | 
			
		||||
                containing the 2D image coordinates of each ray or,
 | 
			
		||||
                if mask is given, `(batch_size, n, 1, 2)`
 | 
			
		||||
            Here `s1, s2` refer to spatial dimensions. Unless the mask is
 | 
			
		||||
            given, they equal `(image_height, image_width)`, otherwise `(n, 1)`,
 | 
			
		||||
            where `n` is `n_rays_per_image` if provided, otherwise the minimum
 | 
			
		||||
            cardinality of the mask in the batch.
 | 
			
		||||
            Here `s1, s2` refer to spatial dimensions.
 | 
			
		||||
            `(s1, s2)` refer to (highest priority first):
 | 
			
		||||
                - `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
 | 
			
		||||
                - `(n_rays_per_image, 1) if `n_rays_per_image` if provided,
 | 
			
		||||
                - `(n, 1)` where n is the minimum cardinality of the mask
 | 
			
		||||
                        in the batch if `mask` is provided
 | 
			
		||||
                - `(image_height, image_width)` if nothing from above is satisfied
 | 
			
		||||
 | 
			
		||||
            `HeterogeneousRayBundle` has additional members:
 | 
			
		||||
                - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
 | 
			
		||||
                    cameras. It represents unique ids of sampled cameras.
 | 
			
		||||
                - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
 | 
			
		||||
                    cameras. Represents how many times each camera from `camera_ids` was sampled
 | 
			
		||||
 | 
			
		||||
            `HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
 | 
			
		||||
            is returned.
 | 
			
		||||
        """
 | 
			
		||||
        n_rays_total = n_rays_total or self._n_rays_total
 | 
			
		||||
        n_rays_per_image = n_rays_per_image or self._n_rays_per_image
 | 
			
		||||
        assert (n_rays_total is None) or (
 | 
			
		||||
            n_rays_per_image is None
 | 
			
		||||
        ), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
 | 
			
		||||
        if n_rays_total:
 | 
			
		||||
            (
 | 
			
		||||
                cameras,
 | 
			
		||||
                mask,
 | 
			
		||||
                camera_ids,  # unique ids of sampled cameras
 | 
			
		||||
                camera_counts,  # number of times unique camera id was sampled
 | 
			
		||||
                # `n_rays_per_image` is equal to the max number of times a simgle camera
 | 
			
		||||
                # was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
 | 
			
		||||
                # and then discard the unneeded rays.
 | 
			
		||||
                # pyre-ignore[9]
 | 
			
		||||
                n_rays_per_image,
 | 
			
		||||
            ) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
 | 
			
		||||
        else:
 | 
			
		||||
            camera_ids = torch.range(0, len(cameras), dtype=torch.long)
 | 
			
		||||
 | 
			
		||||
        batch_size = cameras.R.shape[0]
 | 
			
		||||
        device = cameras.device
 | 
			
		||||
 | 
			
		||||
        # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
 | 
			
		||||
        xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)
 | 
			
		||||
 | 
			
		||||
        num_rays = n_rays_per_image or self._n_rays_per_image
 | 
			
		||||
        if mask is not None and num_rays is None:
 | 
			
		||||
        if mask is not None and n_rays_per_image is None:
 | 
			
		||||
            # if num rays not given, sample according to the smallest mask
 | 
			
		||||
            num_rays = num_rays or mask.sum(dim=(1, 2)).min().int().item()
 | 
			
		||||
            n_rays_per_image = (
 | 
			
		||||
                n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if num_rays is not None:
 | 
			
		||||
        if n_rays_per_image is not None:
 | 
			
		||||
            if mask is not None:
 | 
			
		||||
                assert mask.shape == xy_grid.shape[:3]
 | 
			
		||||
                weights = mask.reshape(batch_size, -1)
 | 
			
		||||
@ -181,7 +230,9 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
                weights = xy_grid.new_ones(batch_size, width * height)
 | 
			
		||||
            # pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
 | 
			
		||||
            #  float, int]`.
 | 
			
		||||
            rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2)
 | 
			
		||||
            rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand(
 | 
			
		||||
                -1, -1, 2
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
 | 
			
		||||
                :, :, None
 | 
			
		||||
@ -198,7 +249,7 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
            else self._stratified_sampling
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return _xy_to_ray_bundle(
 | 
			
		||||
        ray_bundle = _xy_to_ray_bundle(
 | 
			
		||||
            cameras,
 | 
			
		||||
            xy_grid,
 | 
			
		||||
            min_depth,
 | 
			
		||||
@ -208,6 +259,13 @@ class MultinomialRaysampler(torch.nn.Module):
 | 
			
		||||
            stratified_sampling,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            # pyre-ignore[61]
 | 
			
		||||
            _pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
 | 
			
		||||
            if n_rays_total
 | 
			
		||||
            else ray_bundle
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NDCMultinomialRaysampler(MultinomialRaysampler):
 | 
			
		||||
    """
 | 
			
		||||
@ -231,6 +289,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler):
 | 
			
		||||
        min_depth: float,
 | 
			
		||||
        max_depth: float,
 | 
			
		||||
        n_rays_per_image: Optional[int] = None,
 | 
			
		||||
        n_rays_total: Optional[int] = None,
 | 
			
		||||
        unit_directions: bool = False,
 | 
			
		||||
        stratified_sampling: bool = False,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
@ -254,6 +313,7 @@ class NDCMultinomialRaysampler(MultinomialRaysampler):
 | 
			
		||||
            min_depth=min_depth,
 | 
			
		||||
            max_depth=max_depth,
 | 
			
		||||
            n_rays_per_image=n_rays_per_image,
 | 
			
		||||
            n_rays_total=n_rays_total,
 | 
			
		||||
            unit_directions=unit_directions,
 | 
			
		||||
            stratified_sampling=stratified_sampling,
 | 
			
		||||
        )
 | 
			
		||||
@ -281,6 +341,7 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
        min_depth: float,
 | 
			
		||||
        max_depth: float,
 | 
			
		||||
        *,
 | 
			
		||||
        n_rays_total: Optional[int] = None,
 | 
			
		||||
        unit_directions: bool = False,
 | 
			
		||||
        stratified_sampling: bool = False,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
@ -294,6 +355,11 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
            n_pts_per_ray: The number of points sampled along each ray.
 | 
			
		||||
            min_depth: The minimum depth of each ray-point.
 | 
			
		||||
            max_depth: The maximum depth of each ray-point.
 | 
			
		||||
            n_rays_total: How many rays in total to sample from the cameras provided. The result
 | 
			
		||||
                is as if `n_rays_total_training` cameras were sampled with replacement from the
 | 
			
		||||
                cameras provided and for every camera one ray was sampled. If set, this disables
 | 
			
		||||
                `n_rays_per_image` and returns the HeterogeneousRayBundleyBundle with
 | 
			
		||||
                batch_size=n_rays_total.
 | 
			
		||||
            unit_directions: whether to normalize direction vectors in ray bundle.
 | 
			
		||||
            stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
 | 
			
		||||
                bins for each ray; otherwise takes n_pts_per_ray deterministic points
 | 
			
		||||
@ -308,6 +374,7 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
        self._n_pts_per_ray = n_pts_per_ray
 | 
			
		||||
        self._min_depth = min_depth
 | 
			
		||||
        self._max_depth = max_depth
 | 
			
		||||
        self._n_rays_total = n_rays_total
 | 
			
		||||
        self._unit_directions = unit_directions
 | 
			
		||||
        self._stratified_sampling = stratified_sampling
 | 
			
		||||
 | 
			
		||||
@ -317,15 +384,16 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
        *,
 | 
			
		||||
        stratified_sampling: Optional[bool] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> RayBundle:
 | 
			
		||||
    ) -> Union[RayBundle, HeterogeneousRayBundle]:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
 | 
			
		||||
            stratified_sampling: if set, overrides stratified_sampling provided
 | 
			
		||||
                in __init__.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A named tuple RayBundle with the following fields:
 | 
			
		||||
            A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
 | 
			
		||||
            following fields:
 | 
			
		||||
 | 
			
		||||
            origins: A tensor of shape
 | 
			
		||||
                `(batch_size, n_rays_per_image, 3)`
 | 
			
		||||
                denoting the locations of ray origins in the world coordinates.
 | 
			
		||||
@ -338,7 +406,31 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
            xys: A tensor of shape
 | 
			
		||||
                `(batch_size, n_rays_per_image, 2)`
 | 
			
		||||
                containing the 2D image coordinates of each ray.
 | 
			
		||||
            If `n_rays_total` is provided `batch_size=n_rays_total`and
 | 
			
		||||
            `n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
 | 
			
		||||
            is returned.
 | 
			
		||||
 | 
			
		||||
            `HeterogeneousRayBundle` has additional members:
 | 
			
		||||
                - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
 | 
			
		||||
                    cameras. It represents unique ids of sampled cameras.
 | 
			
		||||
                - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
 | 
			
		||||
                    cameras. Represents how many times each camera from `camera_ids` was sampled
 | 
			
		||||
        """
 | 
			
		||||
        assert (self._n_rays_total is None) or (
 | 
			
		||||
            self._n_rays_per_image is None
 | 
			
		||||
        ), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."
 | 
			
		||||
 | 
			
		||||
        if self._n_rays_total:
 | 
			
		||||
            (
 | 
			
		||||
                cameras,
 | 
			
		||||
                _,
 | 
			
		||||
                camera_ids,
 | 
			
		||||
                camera_counts,
 | 
			
		||||
                n_rays_per_image,
 | 
			
		||||
            ) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
 | 
			
		||||
        else:
 | 
			
		||||
            camera_ids = torch.range(0, len(cameras), dtype=torch.long)
 | 
			
		||||
            n_rays_per_image = self._n_rays_per_image
 | 
			
		||||
 | 
			
		||||
        batch_size = cameras.R.shape[0]
 | 
			
		||||
 | 
			
		||||
@ -349,7 +441,7 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
        rays_xy = torch.cat(
 | 
			
		||||
            [
 | 
			
		||||
                torch.rand(
 | 
			
		||||
                    size=(batch_size, self._n_rays_per_image, 1),
 | 
			
		||||
                    size=(batch_size, n_rays_per_image, 1),
 | 
			
		||||
                    dtype=torch.float32,
 | 
			
		||||
                    device=device,
 | 
			
		||||
                )
 | 
			
		||||
@ -369,7 +461,7 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
            else self._stratified_sampling
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return _xy_to_ray_bundle(
 | 
			
		||||
        ray_bundle = _xy_to_ray_bundle(
 | 
			
		||||
            cameras,
 | 
			
		||||
            rays_xy,
 | 
			
		||||
            self._min_depth,
 | 
			
		||||
@ -379,6 +471,13 @@ class MonteCarloRaysampler(torch.nn.Module):
 | 
			
		||||
            stratified_sampling,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            # pyre-ignore[61]
 | 
			
		||||
            _pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
 | 
			
		||||
            if self._n_rays_total
 | 
			
		||||
            else ray_bundle
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Settings for backwards compatibility
 | 
			
		||||
def GridRaysampler(
 | 
			
		||||
@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    # Samples in those intervals.
 | 
			
		||||
    jiggled = lower + (upper - lower) * torch.rand_like(lower)
 | 
			
		||||
    return jiggled
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _sample_cameras_and_masks(
 | 
			
		||||
    n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
 | 
			
		||||
) -> Tuple[
 | 
			
		||||
    CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
 | 
			
		||||
]:
 | 
			
		||||
    """
 | 
			
		||||
    Samples n_rays_total cameras and masks and returns them in a form
 | 
			
		||||
    (camera_idx, count), where count represents number of times the same camera
 | 
			
		||||
    has been sampled.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        n_samples: how many camera and mask pairs to sample
 | 
			
		||||
        cameras: A batch of `batch_size` cameras from which the rays are emitted.
 | 
			
		||||
        mask: Optional. Should be of size (batch_size, image_height, image_width).
 | 
			
		||||
    Returns:
 | 
			
		||||
        tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids,
 | 
			
		||||
            number_of_times_each_sampled_camera_has_been_sampled,
 | 
			
		||||
            max_number_of_times_camera_has_been_sampled,
 | 
			
		||||
            )
 | 
			
		||||
    """
 | 
			
		||||
    sampled_ids = torch.randint(
 | 
			
		||||
        0,
 | 
			
		||||
        len(cameras),
 | 
			
		||||
        size=(n_samples,),
 | 
			
		||||
        dtype=torch.long,
 | 
			
		||||
    )
 | 
			
		||||
    unique_ids, counts = torch.unique(sampled_ids, return_counts=True)
 | 
			
		||||
    return (
 | 
			
		||||
        cameras[unique_ids],
 | 
			
		||||
        mask[unique_ids] if mask is not None else None,
 | 
			
		||||
        unique_ids,
 | 
			
		||||
        counts,
 | 
			
		||||
        torch.max(counts),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _pack_ray_bundle(
 | 
			
		||||
    ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
 | 
			
		||||
) -> HeterogeneousRayBundle:
 | 
			
		||||
    """
 | 
			
		||||
    Pack the raybundle from [n_cameras, max(rays_per_camera), ...] to
 | 
			
		||||
        [total_num_rays, 1, ...]
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        ray_bundle: A ray_bundle to pack
 | 
			
		||||
        camera_ids: Unique ids of cameras that were sampled
 | 
			
		||||
        camera_counts: how many of which camera to pack, each count coresponds to
 | 
			
		||||
            one 'row' of the ray_bundle and says how many rays wll be taken
 | 
			
		||||
            from it and packed.
 | 
			
		||||
    Returns:
 | 
			
		||||
        HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1
 | 
			
		||||
    """
 | 
			
		||||
    camera_counts = camera_counts.to(ray_bundle.origins.device)
 | 
			
		||||
    cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
 | 
			
		||||
    first_idxs = torch.cat(
 | 
			
		||||
        (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
 | 
			
		||||
    )
 | 
			
		||||
    num_inputs = int(camera_counts.sum())
 | 
			
		||||
 | 
			
		||||
    return HeterogeneousRayBundle(
 | 
			
		||||
        origins=padded_to_packed(ray_bundle.origins, first_idxs, num_inputs)[:, None],
 | 
			
		||||
        directions=padded_to_packed(ray_bundle.directions, first_idxs, num_inputs)[
 | 
			
		||||
            :, None
 | 
			
		||||
        ],
 | 
			
		||||
        lengths=padded_to_packed(ray_bundle.lengths, first_idxs, num_inputs)[:, None],
 | 
			
		||||
        xys=padded_to_packed(ray_bundle.xys, first_idxs, num_inputs)[:, None],
 | 
			
		||||
        camera_ids=camera_ids,
 | 
			
		||||
        camera_counts=camera_counts,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
from typing import Callable, Tuple
 | 
			
		||||
from typing import Callable, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,7 @@ from ...ops.utils import eyes
 | 
			
		||||
from ...structures import Volumes
 | 
			
		||||
from ...transforms import Transform3d
 | 
			
		||||
from ..cameras import CamerasBase
 | 
			
		||||
from .raysampling import RayBundle
 | 
			
		||||
from .raysampling import HeterogeneousRayBundle, RayBundle
 | 
			
		||||
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -44,12 +44,14 @@ class ImplicitRenderer(torch.nn.Module):
 | 
			
		||||
    A standard `volumetric_function` has the following signature:
 | 
			
		||||
    ```
 | 
			
		||||
    def volumetric_function(
 | 
			
		||||
        ray_bundle: RayBundle,
 | 
			
		||||
        ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]
 | 
			
		||||
    ```
 | 
			
		||||
    With the following arguments:
 | 
			
		||||
        `ray_bundle`: A RayBundle object containing the following variables:
 | 
			
		||||
        `ray_bundle`: A RayBundle or HeterogeneousRayBundle object
 | 
			
		||||
            containing the following variables:
 | 
			
		||||
 | 
			
		||||
            `origins`: A tensor of shape `(minibatch, ..., 3)` denoting
 | 
			
		||||
                the origins of the rendering rays.
 | 
			
		||||
            `directions`: A tensor of shape `(minibatch, ..., 3)`
 | 
			
		||||
@ -80,7 +82,7 @@ class ImplicitRenderer(torch.nn.Module):
 | 
			
		||||
        RGB sphere with a unit diameter is defined as follows:
 | 
			
		||||
        ```
 | 
			
		||||
        def volumetric_function(
 | 
			
		||||
            ray_bundle: RayBundle,
 | 
			
		||||
            ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
 | 
			
		||||
            **kwargs,
 | 
			
		||||
        ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
 | 
			
		||||
@ -109,7 +111,8 @@ class ImplicitRenderer(torch.nn.Module):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            raysampler: A `Callable` that takes as input scene cameras
 | 
			
		||||
                (an instance of `CamerasBase`) and returns a `RayBundle` that
 | 
			
		||||
                (an instance of `CamerasBase`) and returns a
 | 
			
		||||
                RayBundle or HeterogeneousRayBundle, that
 | 
			
		||||
                describes the rays emitted from the cameras.
 | 
			
		||||
            raymarcher: A `Callable` that receives the response of the
 | 
			
		||||
                `volumetric_function` (an input to `self.forward`) evaluated
 | 
			
		||||
@ -128,7 +131,7 @@ class ImplicitRenderer(torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
 | 
			
		||||
    ) -> Tuple[torch.Tensor, RayBundle]:
 | 
			
		||||
    ) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]:
 | 
			
		||||
        """
 | 
			
		||||
        Render a batch of images using a volumetric function
 | 
			
		||||
        represented as a callable (e.g. a Pytorch module).
 | 
			
		||||
@ -145,15 +148,15 @@ class ImplicitRenderer(torch.nn.Module):
 | 
			
		||||
        Returns:
 | 
			
		||||
            images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
 | 
			
		||||
                containing the result of the rendering.
 | 
			
		||||
            ray_bundle: A `RayBundle` containing the parametrizations of the
 | 
			
		||||
                sampled rendering rays.
 | 
			
		||||
            ray_bundle: A `Union[RayBundle, HeterogeneousRayBundle]` containing
 | 
			
		||||
                the parametrizations of the sampled rendering rays.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if not callable(volumetric_function):
 | 
			
		||||
            raise ValueError('"volumetric_function" has to be a "Callable" object.')
 | 
			
		||||
 | 
			
		||||
        # first call the ray sampler that returns the RayBundle parametrizing
 | 
			
		||||
        # the rendering rays.
 | 
			
		||||
        # first call the ray sampler that returns the RayBundle or HeterogeneousRayBundle
 | 
			
		||||
        # parametrizing the rendering rays.
 | 
			
		||||
        ray_bundle = self.raysampler(
 | 
			
		||||
            cameras=cameras, volumetric_function=volumetric_function, **kwargs
 | 
			
		||||
        )
 | 
			
		||||
@ -211,7 +214,8 @@ class VolumeRenderer(torch.nn.Module):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            raysampler: A `Callable` that takes as input scene cameras
 | 
			
		||||
                (an instance of `CamerasBase`) and returns a `RayBundle` that
 | 
			
		||||
                (an instance of `CamerasBase`) and returns a
 | 
			
		||||
                `Union[RayBundle, HeterogeneousRayBundle],` that
 | 
			
		||||
                describes the rays emitted from the cameras.
 | 
			
		||||
            raymarcher: A `Callable` that receives the `volumes`
 | 
			
		||||
                (an instance of `Volumes` input to `self.forward`)
 | 
			
		||||
@ -227,7 +231,7 @@ class VolumeRenderer(torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, cameras: CamerasBase, volumes: Volumes, **kwargs
 | 
			
		||||
    ) -> Tuple[torch.Tensor, RayBundle]:
 | 
			
		||||
    ) -> Tuple[torch.Tensor, Union[RayBundle, HeterogeneousRayBundle]]:
 | 
			
		||||
        """
 | 
			
		||||
        Render a batch of images using raymarching over rays cast through
 | 
			
		||||
        input `Volumes`.
 | 
			
		||||
@ -242,8 +246,8 @@ class VolumeRenderer(torch.nn.Module):
 | 
			
		||||
        Returns:
 | 
			
		||||
            images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)`
 | 
			
		||||
                containing the result of the rendering.
 | 
			
		||||
            ray_bundle: A `RayBundle` containing the parametrizations of the
 | 
			
		||||
                sampled rendering rays.
 | 
			
		||||
            ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` containing the
 | 
			
		||||
                parametrizations of the sampled rendering rays.
 | 
			
		||||
        """
 | 
			
		||||
        volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
 | 
			
		||||
        return self.renderer(
 | 
			
		||||
@ -288,14 +292,14 @@ class VolumeSampler(torch.nn.Module):
 | 
			
		||||
        return directions_transform
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, ray_bundle: RayBundle, **kwargs
 | 
			
		||||
        self, ray_bundle: Union[RayBundle, HeterogeneousRayBundle], **kwargs
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Given an input ray parametrization, the forward function samples
 | 
			
		||||
        `self._volumes` at the respective 3D ray-points.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            ray_bundle: A RayBundle object with the following fields:
 | 
			
		||||
            ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields:
 | 
			
		||||
                rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
 | 
			
		||||
                    origins of the sampling rays in world coords.
 | 
			
		||||
                rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
 | 
			
		||||
 | 
			
		||||
@ -4,19 +4,25 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
from typing import NamedTuple
 | 
			
		||||
import dataclasses
 | 
			
		||||
from typing import NamedTuple, Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RayBundle(NamedTuple):
 | 
			
		||||
    """
 | 
			
		||||
    RayBundle parametrizes points along projection rays by storing ray `origins`,
 | 
			
		||||
    `directions` vectors and `lengths` at which the ray-points are sampled.
 | 
			
		||||
    Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
 | 
			
		||||
    Note that `directions` don't have to be normalized; they define unit vectors
 | 
			
		||||
    in the respective 1D coordinate systems; see documentation for
 | 
			
		||||
    :func:`ray_bundle_to_ray_points` for the conversion formula.
 | 
			
		||||
    Parametrizes points along projection rays by storing:
 | 
			
		||||
 | 
			
		||||
        origins: A tensor of shape `(..., 3)` denoting the
 | 
			
		||||
            origins of the sampling rays in world coords.
 | 
			
		||||
        directions: A tensor of shape `(..., 3)` containing the direction
 | 
			
		||||
            vectors of sampling rays in world coords. They don't have to be normalized;
 | 
			
		||||
            they define unit vectors in the respective 1D coordinate systems; see
 | 
			
		||||
            documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
 | 
			
		||||
        lengths: A tensor of shape `(..., num_points_per_ray)`
 | 
			
		||||
            containing the lengths at which the rays are sampled.
 | 
			
		||||
        xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    origins: torch.Tensor
 | 
			
		||||
@ -25,11 +31,46 @@ class RayBundle(NamedTuple):
 | 
			
		||||
    xys: torch.Tensor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class HeterogeneousRayBundle:
 | 
			
		||||
    """
 | 
			
		||||
    Members:
 | 
			
		||||
        origins: A tensor of shape `(..., 3)` denoting the
 | 
			
		||||
            origins of the sampling rays in world coords.
 | 
			
		||||
        directions: A tensor of shape `(..., 3)` containing the direction
 | 
			
		||||
            vectors of sampling rays in world coords. They don't have to be normalized;
 | 
			
		||||
            they define unit vectors in the respective 1D coordinate systems; see
 | 
			
		||||
            documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
 | 
			
		||||
        lengths: A tensor of shape `(..., num_points_per_ray)`
 | 
			
		||||
            containing the lengths at which the rays are sampled.
 | 
			
		||||
        xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
 | 
			
		||||
        camera_ids: A tensor of shape (N, ) which indicates which camera
 | 
			
		||||
            was used to sample the rays. `N` is the number of unique sampled cameras.
 | 
			
		||||
        camera_counts: A tensor of shape (N, ) which how many times the
 | 
			
		||||
            coresponding camera in `camera_ids` was sampled.
 | 
			
		||||
            `sum(camera_counts)==total_number_of_rays`
 | 
			
		||||
 | 
			
		||||
    If we sample cameras of ids [0, 3, 5, 3, 1, 0, 0] that would be
 | 
			
		||||
    stored as camera_ids=[1, 3, 5, 0] and camera_counts=[1, 2, 1, 3]. `camera_ids` is a
 | 
			
		||||
    set like object with no particular ordering of elements. ith element of
 | 
			
		||||
    `camera_ids` coresponds to the ith element of `camera_counts`.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    origins: torch.Tensor
 | 
			
		||||
    directions: torch.Tensor
 | 
			
		||||
    lengths: torch.Tensor
 | 
			
		||||
    xys: torch.Tensor
 | 
			
		||||
    camera_ids: Optional[torch.Tensor] = None
 | 
			
		||||
    camera_counts: Optional[torch.Tensor] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ray_bundle_to_ray_points(
 | 
			
		||||
    ray_bundle: Union[RayBundle, HeterogeneousRayBundle]
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
 | 
			
		||||
    named tuple) to 3D points by extending each ray according to the corresponding
 | 
			
		||||
    length.
 | 
			
		||||
    named tuple or HeterogeneousRayBundle dataclass) to 3D points by
 | 
			
		||||
    extending each ray according to the corresponding length.
 | 
			
		||||
 | 
			
		||||
    E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
 | 
			
		||||
        and `ray_bundle.lengths`, the ray point at position `[i, j]` is:
 | 
			
		||||
@ -43,7 +84,7 @@ def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
 | 
			
		||||
    `ray_bundle.directions` matter.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        ray_bundle: A `RayBundle` object with fields:
 | 
			
		||||
        ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields:
 | 
			
		||||
            origins: A tensor of shape `(..., 3)`
 | 
			
		||||
            directions: A tensor of shape `(..., 3)`
 | 
			
		||||
            lengths: A tensor of shape `(..., num_points_per_ray)`
 | 
			
		||||
 | 
			
		||||
@ -152,6 +152,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        min_depth=1.0,
 | 
			
		||||
        max_depth=10.0,
 | 
			
		||||
        n_pts_per_ray=10,
 | 
			
		||||
        n_rays_total=None,
 | 
			
		||||
        n_rays_per_image=None,
 | 
			
		||||
    ):
 | 
			
		||||
        raysampler_params = {
 | 
			
		||||
            "min_x": min_x,
 | 
			
		||||
@ -161,6 +163,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            "n_pts_per_ray": n_pts_per_ray,
 | 
			
		||||
            "min_depth": min_depth,
 | 
			
		||||
            "max_depth": max_depth,
 | 
			
		||||
            "n_rays_total": n_rays_total,
 | 
			
		||||
            "n_rays_per_image": n_rays_per_image,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if issubclass(raysampler_type, MultinomialRaysampler):
 | 
			
		||||
@ -168,7 +172,11 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                {"image_width": image_width, "image_height": image_height}
 | 
			
		||||
            )
 | 
			
		||||
        elif issubclass(raysampler_type, MonteCarloRaysampler):
 | 
			
		||||
            raysampler_params["n_rays_per_image"] = image_width * image_height
 | 
			
		||||
            raysampler_params["n_rays_per_image"] = (
 | 
			
		||||
                image_width * image_height
 | 
			
		||||
                if (n_rays_total is None) and (n_rays_per_image is None)
 | 
			
		||||
                else n_rays_per_image
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(str(raysampler_type))
 | 
			
		||||
 | 
			
		||||
@ -580,3 +588,55 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            # samples[3] has enough sources, so must contain 3 distinct values.
 | 
			
		||||
            self.assertLessEqual(samples[3].max(), 3)
 | 
			
		||||
            self.assertEqual(len(set(samples[3].tolist())), 3)
 | 
			
		||||
 | 
			
		||||
    def test_heterogeneous_sampling(self, batch_size=8):
 | 
			
		||||
        """
 | 
			
		||||
        Test that the output of heterogeneous sampling has the first dimension equal
 | 
			
		||||
        to n_rays_total and second to 1 and that ray_bundle elements from different
 | 
			
		||||
        sampled cameras are different and equal for same sampled cameras.
 | 
			
		||||
        """
 | 
			
		||||
        cameras = init_random_cameras(PerspectiveCameras, batch_size, random_z=True)
 | 
			
		||||
        for n_rays_total in [2, 3, 17, 21, 32]:
 | 
			
		||||
            for cls in (MultinomialRaysampler, MonteCarloRaysampler):
 | 
			
		||||
                with self.subTest(cls.__name__ + ", n_rays_total=" + str(n_rays_total)):
 | 
			
		||||
                    raysampler = self.init_raysampler(
 | 
			
		||||
                        cls, n_rays_total=n_rays_total, n_rays_per_image=None
 | 
			
		||||
                    )
 | 
			
		||||
                    ray_bundle = raysampler(cameras)
 | 
			
		||||
 | 
			
		||||
                    # test weather they are of the correct shape
 | 
			
		||||
                    for attr in ("origins", "directions", "lengths", "xys"):
 | 
			
		||||
                        tensor = getattr(ray_bundle, attr)
 | 
			
		||||
                        assert tensor.shape[:2] == torch.Size(
 | 
			
		||||
                            (n_rays_total, 1)
 | 
			
		||||
                        ), tensor.shape
 | 
			
		||||
 | 
			
		||||
                    # if two camera ids are same than origins should also be the same
 | 
			
		||||
                    # directions and xys are always different and lengths equal
 | 
			
		||||
                    for i1, (origin1, dir1, len1, id1) in enumerate(
 | 
			
		||||
                        zip(
 | 
			
		||||
                            ray_bundle.origins,
 | 
			
		||||
                            ray_bundle.directions,
 | 
			
		||||
                            ray_bundle.lengths,
 | 
			
		||||
                            torch.repeat_interleave(
 | 
			
		||||
                                ray_bundle.camera_ids, ray_bundle.camera_counts
 | 
			
		||||
                            ),
 | 
			
		||||
                        )
 | 
			
		||||
                    ):
 | 
			
		||||
                        for i2, (origin2, dir2, len2, id2) in enumerate(
 | 
			
		||||
                            zip(
 | 
			
		||||
                                ray_bundle.origins,
 | 
			
		||||
                                ray_bundle.directions,
 | 
			
		||||
                                ray_bundle.lengths,
 | 
			
		||||
                                torch.repeat_interleave(
 | 
			
		||||
                                    ray_bundle.camera_ids, ray_bundle.camera_counts
 | 
			
		||||
                                ),
 | 
			
		||||
                            )
 | 
			
		||||
                        ):
 | 
			
		||||
                            if i1 == i2:
 | 
			
		||||
                                continue
 | 
			
		||||
                            assert torch.allclose(
 | 
			
		||||
                                origin1, origin2, rtol=1e-4, atol=1e-4
 | 
			
		||||
                            ) == (id1 == id2), (origin1, origin2, id1, id2)
 | 
			
		||||
                            assert not torch.allclose(dir1, dir2), (dir1, dir2)
 | 
			
		||||
                            self.assertClose(len1, len2), (len1, len2)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user