mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 11:50:35 +08:00
Add utils to approximate the conical frustums as multivariate gaussians.
Summary: Introduce methods to approximate the radii of conical frustums along rays as described in [MipNerf](https://arxiv.org/abs/2103.13415): - Two new attributes are added to ImplicitronRayBundle: bins and radii. Bins is of size n_pts_per_ray + 1. It allows us to manipulate easily and n_pts_per_ray intervals. For example we need the intervals coordinates in the radii computation for \(t_{\mu}, t_{\delta}\). Radii are used to store the radii of the conical frustums. - Add 3 new methods to compute the radii: - approximate_conical_frustum_as_gaussians: It computes the mean along the ray direction, the variance of the conical frustum with respect to t and variance of the conical frustum with respect to its radius. This implementation follows the stable computation defined in the paper. - compute_3d_diagonal_covariance_gaussian: Will leverage the two previously computed variances to find the diagonal covariance of the Gaussian. - conical_frustum_to_gaussian: Mix everything together to compute the means and the diagonal covariances along the ray of the Gaussians. - In AbstractMaskRaySampler, introduces the attribute `cast_ray_bundle_as_cone`. If False it won't change the previous behaviour of the RaySampler. However if True, the samplers will sample `n_pts_per_ray +1` instead of `n_pts_per_ray`. This points are then used to set the bins attribute of ImplicitronRayBundle. The support of HeterogeneousRayBundle has not been added since the current code does not allow it. A safeguard has been added to avoid a silent bug in the future. Reviewed By: shapovalov Differential Revision: D45269190 fbshipit-source-id: bf22fad12d71d55392f054e3f680013aa0d59b78
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4e7715ce66
commit
29b8ebd802
@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||
from pytorch3d.ops import packed_to_padded
|
||||
from pytorch3d.renderer.implicit.utils import ray_bundle_variables_to_ray_points
|
||||
|
||||
|
||||
class EvaluationMode(Enum):
|
||||
@@ -47,6 +48,27 @@ class ImplicitronRayBundle:
|
||||
camera_counts: A tensor of shape (N, ) which how many times the
|
||||
coresponding camera in `camera_ids` was sampled.
|
||||
`sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.
|
||||
|
||||
Attributes:
|
||||
origins: A tensor of shape `(..., 3)` denoting the
|
||||
origins of the sampling rays in world coords.
|
||||
directions: A tensor of shape `(..., 3)` containing the direction
|
||||
vectors of sampling rays in world coords. They don't have to be normalized;
|
||||
they define unit vectors in the respective 1D coordinate systems; see
|
||||
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
|
||||
lengths: A tensor of shape `(..., num_points_per_ray)`
|
||||
containing the lengths at which the rays are sampled.
|
||||
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
|
||||
camera_ids: An optional tensor of shape (N, ) which indicates which camera
|
||||
was used to sample the rays. `N` is the number of unique sampled cameras.
|
||||
camera_counts: An optional tensor of shape (N, ) indicates how many times the
|
||||
coresponding camera in `camera_ids` was sampled.
|
||||
`sum(camera_counts)==total_number_of_rays`.
|
||||
bins: An optional tensor of shape `(..., num_points_per_ray + 1)`
|
||||
containing the bins at which the rays are sampled. In this case
|
||||
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
|
||||
pixel_radii_2d: An optional tensor of shape `(..., 1)`
|
||||
base radii of the conical frustums.
|
||||
"""
|
||||
|
||||
origins: torch.Tensor
|
||||
@@ -55,6 +77,45 @@ class ImplicitronRayBundle:
|
||||
xys: torch.Tensor
|
||||
camera_ids: Optional[torch.LongTensor] = None
|
||||
camera_counts: Optional[torch.LongTensor] = None
|
||||
bins: Optional[torch.Tensor] = None
|
||||
pixel_radii_2d: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_bins(
|
||||
cls,
|
||||
origins: torch.Tensor,
|
||||
directions: torch.Tensor,
|
||||
bins: torch.Tensor,
|
||||
xys: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> "ImplicitronRayBundle":
|
||||
"""
|
||||
Creates a new instance from bins instead of lengths.
|
||||
|
||||
Attributes:
|
||||
origins: A tensor of shape `(..., 3)` denoting the
|
||||
origins of the sampling rays in world coords.
|
||||
directions: A tensor of shape `(..., 3)` containing the direction
|
||||
vectors of sampling rays in world coords. They don't have to be normalized;
|
||||
they define unit vectors in the respective 1D coordinate systems; see
|
||||
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
|
||||
bins: A tensor of shape `(..., num_points_per_ray + 1)`
|
||||
containing the bins at which the rays are sampled. In this case
|
||||
lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
|
||||
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
|
||||
kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
|
||||
Returns:
|
||||
An instance of ImplicitronRayBundle.
|
||||
"""
|
||||
|
||||
if bins.shape[-1] <= 1:
|
||||
raise ValueError(
|
||||
"The last dim of bins must be at least superior or equal to 2."
|
||||
)
|
||||
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
|
||||
lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
|
||||
|
||||
return cls(origins, directions, lengths, xys, bins=bins, **kwargs)
|
||||
|
||||
def is_packed(self) -> bool:
|
||||
"""
|
||||
@@ -195,3 +256,154 @@ class BaseRenderer(ABC, ReplaceableBase):
|
||||
instance of RendererOutput
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def compute_3d_diagonal_covariance_gaussian(
|
||||
rays_directions: torch.Tensor,
|
||||
rays_dir_variance: torch.Tensor,
|
||||
radii_variance: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Transform the variances (rays_dir_variance, radii_variance) of the gaussians from
|
||||
the coordinate frame of the conical frustum to 3D world coordinates.
|
||||
|
||||
It follows the equation 16 of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_
|
||||
|
||||
Args:
|
||||
rays_directions: A tensor of shape `(..., 3)`
|
||||
rays_dir_variance: A tensor of shape `(..., num_intervals)` representing
|
||||
the variance of the conical frustum with respect to the rays direction.
|
||||
radii_variance: A tensor of shape `(..., num_intervals)` representing
|
||||
the variance of the conical frustum with respect to its radius.
|
||||
eps: a small number to prevent division by zero.
|
||||
|
||||
Returns:
|
||||
A tensor of shape `(..., num_intervals, 3)` containing the diagonal
|
||||
of the covariance matrix.
|
||||
"""
|
||||
d_outer_diag = torch.pow(rays_directions, 2)
|
||||
dir_mag_sq = torch.clamp(torch.sum(d_outer_diag, dim=-1, keepdim=True), min=eps)
|
||||
|
||||
null_outer_diag = 1 - d_outer_diag / dir_mag_sq
|
||||
ray_dir_cov_diag = rays_dir_variance[..., None] * d_outer_diag[..., None, :]
|
||||
xy_cov_diag = radii_variance[..., None] * null_outer_diag[..., None, :]
|
||||
return ray_dir_cov_diag + xy_cov_diag
|
||||
|
||||
|
||||
def approximate_conical_frustum_as_gaussians(
|
||||
bins: torch.Tensor, radii: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Approximates a conical frustum as two Gaussian distributions.
|
||||
|
||||
The Gaussian distributions are characterized by
|
||||
three values:
|
||||
|
||||
- rays_dir_mean: mean along the rays direction
|
||||
(defined as t in the parametric representation of a cone).
|
||||
- rays_dir_variance: the variance of the conical frustum along the rays direction.
|
||||
- radii_variance: variance of the conical frustum with respect to its radius.
|
||||
|
||||
|
||||
The computation is stable and follows equation 7
|
||||
of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||
|
||||
For more information on how the mean and variances are computed
|
||||
refers to the appendix of the paper.
|
||||
|
||||
Args:
|
||||
bins: A tensor of shape `(..., num_points_per_ray + 1)`
|
||||
containing the bins at which the rays are sampled.
|
||||
`bin[..., t]` and `bin[..., t+1]` represent respectively
|
||||
the left and right coordinates of the interval.
|
||||
t0: A tensor of shape `(..., num_points_per_ray)`
|
||||
containing the left coordinates of the intervals
|
||||
on which the rays are sampled.
|
||||
t1: A tensor of shape `(..., num_points_per_ray)`
|
||||
containing the rights coordinates of the intervals
|
||||
on which the rays are sampled.
|
||||
radii: A tensor of shape `(..., 1)`
|
||||
base radii of the conical frustums.
|
||||
|
||||
Returns:
|
||||
rays_dir_mean: A tensor of shape `(..., num_intervals)` representing
|
||||
the mean along the rays direction
|
||||
(t in the parametric represention of the cone)
|
||||
rays_dir_variance: A tensor of shape `(..., num_intervals)` representing
|
||||
the variance of the conical frustum along the rays
|
||||
(t in the parametric represention of the cone).
|
||||
radii_variance: A tensor of shape `(..., num_intervals)` representing
|
||||
the variance of the conical frustum with respect to its radius.
|
||||
"""
|
||||
t_mu = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
|
||||
t_delta = torch.diff(bins, dim=-1) / 2
|
||||
|
||||
t_mu_pow2 = torch.pow(t_mu, 2)
|
||||
t_delta_pow2 = torch.pow(t_delta, 2)
|
||||
t_delta_pow4 = torch.pow(t_delta, 4)
|
||||
|
||||
den = 3 * t_mu_pow2 + t_delta_pow2
|
||||
|
||||
# mean along the rays direction
|
||||
rays_dir_mean = t_mu + 2 * t_mu * t_delta_pow2 / den
|
||||
|
||||
# Variance of the conical frustum with along the rays directions
|
||||
rays_dir_variance = t_delta_pow2 / 3 - (4 / 15) * (
|
||||
t_delta_pow4 * (12 * t_mu_pow2 - t_delta_pow2) / torch.pow(den, 2)
|
||||
)
|
||||
|
||||
# Variance of the conical frustum with respect to its radius
|
||||
radii_variance = torch.pow(radii, 2) * (
|
||||
t_mu_pow2 / 4 + (5 / 12) * t_delta_pow2 - 4 / 15 * (t_delta_pow4) / den
|
||||
)
|
||||
return rays_dir_mean, rays_dir_variance, radii_variance
|
||||
|
||||
|
||||
def conical_frustum_to_gaussian(
|
||||
ray_bundle: ImplicitronRayBundle,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Approximate a conical frustum following a ray bundle as a Gaussian.
|
||||
|
||||
Args:
|
||||
ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields:
|
||||
origins: A tensor of shape `(..., 3)`
|
||||
directions: A tensor of shape `(..., 3)`
|
||||
lengths: A tensor of shape `(..., num_points_per_ray)`
|
||||
bins: A tensor of shape `(..., num_points_per_ray + 1)`
|
||||
containing the bins at which the rays are sampled. .
|
||||
pixel_radii_2d: A tensor of shape `(..., 1)`
|
||||
base radii of the conical frustums.
|
||||
|
||||
Returns:
|
||||
means: A tensor of shape `(..., num_points_per_ray - 1, 3)`
|
||||
representing the means of the Gaussians
|
||||
approximating the conical frustums.
|
||||
diag_covariances: A tensor of shape `(...,num_points_per_ray -1, 3)`
|
||||
representing the diagonal covariance matrices of our Gaussians.
|
||||
"""
|
||||
|
||||
if ray_bundle.pixel_radii_2d is None or ray_bundle.bins is None:
|
||||
raise ValueError(
|
||||
"RayBundle pixel_radii_2d or bins have not been provided."
|
||||
" Look at pytorch3d.renderer.implicit.renderer.ray_sampler::"
|
||||
"AbstractMaskRaySampler to see how to compute them. Have you forgot to set"
|
||||
"`cast_ray_bundle_as_cone` to True?"
|
||||
)
|
||||
|
||||
(
|
||||
rays_dir_mean,
|
||||
rays_dir_variance,
|
||||
radii_variance,
|
||||
) = approximate_conical_frustum_as_gaussians(
|
||||
ray_bundle.bins,
|
||||
ray_bundle.pixel_radii_2d,
|
||||
)
|
||||
means = ray_bundle_variables_to_ray_points(
|
||||
ray_bundle.origins, ray_bundle.directions, rays_dir_mean
|
||||
)
|
||||
diag_covariances = compute_3d_diagonal_covariance_gaussian(
|
||||
ray_bundle.directions, rays_dir_variance, radii_variance
|
||||
)
|
||||
return means, diag_covariances
|
||||
|
||||
@@ -11,6 +11,7 @@ from pytorch3d.implicitron.tools import camera_utils
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer import NDCMultinomialRaysampler
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle
|
||||
|
||||
from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode
|
||||
|
||||
@@ -83,7 +84,20 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
stratified_point_sampling_training: if set, performs stratified random sampling
|
||||
along the ray; otherwise takes ray points at deterministic offsets.
|
||||
stratified_point_sampling_evaluation: Same as above but for evaluation.
|
||||
cast_ray_bundle_as_cone: If True, the sampling will generate the bins and radii
|
||||
attribute of ImplicitronRayBundle. The `bins` contain the z-coordinate
|
||||
(=depth) of each ray in world units and are of shape
|
||||
`(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation + 1)`
|
||||
while `lengths` is equal to the midpoint of the bins:
|
||||
(0.5 * (bins[..., 1:] + bins[..., :-1]).
|
||||
If False, `bins` is None, `radii` is None and `lengths` contains
|
||||
the z-coordinate (=depth) of each ray in world units and are of shape
|
||||
`(batch_size, n_rays_per_image, n_pts_per_ray_training/evaluation)`
|
||||
|
||||
Raises:
|
||||
TypeError: if cast_ray_bundle_as_cone is set to True and n_rays_total_training
|
||||
is not None will result in an error. HeterogeneousRayBundle is
|
||||
not supported for conical frustum computation yet.
|
||||
"""
|
||||
|
||||
image_width: int = 400
|
||||
@@ -97,6 +111,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
# stratified sampling vs taking points at deterministic offsets
|
||||
stratified_point_sampling_training: bool = True
|
||||
stratified_point_sampling_evaluation: bool = False
|
||||
cast_ray_bundle_as_cone: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if (self.n_rays_per_image_sampled_from_mask is not None) and (
|
||||
@@ -114,10 +129,20 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
),
|
||||
}
|
||||
|
||||
n_pts_per_ray_training = (
|
||||
self.n_pts_per_ray_training + 1
|
||||
if self.cast_ray_bundle_as_cone
|
||||
else self.n_pts_per_ray_training
|
||||
)
|
||||
n_pts_per_ray_evaluation = (
|
||||
self.n_pts_per_ray_evaluation + 1
|
||||
if self.cast_ray_bundle_as_cone
|
||||
else self.n_pts_per_ray_evaluation
|
||||
)
|
||||
self._training_raysampler = NDCMultinomialRaysampler(
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_training,
|
||||
n_pts_per_ray=n_pts_per_ray_training,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
@@ -132,7 +157,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
self._evaluation_raysampler = NDCMultinomialRaysampler(
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
||||
n_pts_per_ray=n_pts_per_ray_evaluation,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
@@ -143,6 +168,11 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
stratified_sampling=self.stratified_point_sampling_evaluation,
|
||||
)
|
||||
|
||||
max_y, min_y = self._training_raysampler.max_y, self._training_raysampler.min_y
|
||||
max_x, min_x = self._training_raysampler.max_x, self._training_raysampler.min_x
|
||||
self.pixel_height: float = (max_y - min_y) / (self.image_height - 1)
|
||||
self.pixel_width: float = (max_x - min_x) / (self.image_width - 1)
|
||||
|
||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -193,19 +223,34 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
min_depth=min_depth,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
||||
if isinstance(ray_bundle, tuple):
|
||||
return ImplicitronRayBundle(
|
||||
# pyre-ignore[16]
|
||||
**{k: v for k, v in ray_bundle._asdict().items()}
|
||||
if self.cast_ray_bundle_as_cone and isinstance(
|
||||
ray_bundle, HeterogeneousRayBundle
|
||||
):
|
||||
# If this error rises it means that raysampler has among
|
||||
# its arguments `n_ray_totals`. If it is the case
|
||||
# then you should update the radii computation and lengths
|
||||
# computation to handle padding and unpadding.
|
||||
raise TypeError(
|
||||
"Heterogeneous ray bundle is not supported for conical frustum computation yet"
|
||||
)
|
||||
elif self.cast_ray_bundle_as_cone:
|
||||
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
||||
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
||||
return ImplicitronRayBundle.from_bins(
|
||||
directions=ray_bundle.directions,
|
||||
origins=ray_bundle.origins,
|
||||
bins=ray_bundle.lengths,
|
||||
xys=ray_bundle.xys,
|
||||
pixel_radii_2d=pixel_radii_2d,
|
||||
)
|
||||
|
||||
return ImplicitronRayBundle(
|
||||
directions=ray_bundle.directions,
|
||||
origins=ray_bundle.origins,
|
||||
lengths=ray_bundle.lengths,
|
||||
xys=ray_bundle.xys,
|
||||
camera_ids=ray_bundle.camera_ids,
|
||||
camera_counts=ray_bundle.camera_counts,
|
||||
camera_counts=getattr(ray_bundle, "camera_counts", None),
|
||||
camera_ids=getattr(ray_bundle, "camera_ids", None),
|
||||
)
|
||||
|
||||
|
||||
@@ -274,3 +319,62 @@ class NearFarRaySampler(AbstractMaskRaySampler):
|
||||
Returns the stored near/far planes.
|
||||
"""
|
||||
return self.min_depth, self.max_depth
|
||||
|
||||
|
||||
def compute_radii(
|
||||
cameras: CamerasBase,
|
||||
xy_grid: torch.Tensor,
|
||||
pixel_hw_ndc: Tuple[float, float],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute radii of conical frustums in world coordinates.
|
||||
|
||||
Args:
|
||||
cameras: cameras object representing a batch of cameras.
|
||||
xy_grid: torch.tensor grid of image xy coords.
|
||||
pixel_hw_ndc: pixel height and width in NDC
|
||||
|
||||
Returns:
|
||||
radii: A tensor of shape `(..., 1)` radii of a cone.
|
||||
"""
|
||||
batch_size = xy_grid.shape[0]
|
||||
spatial_size = xy_grid.shape[1:-1]
|
||||
n_rays_per_image = spatial_size.numel()
|
||||
|
||||
xy = xy_grid.view(batch_size, n_rays_per_image, 2)
|
||||
|
||||
# [batch_size, 3 * n_rays_per_image, 2]
|
||||
xy = torch.cat(
|
||||
[
|
||||
xy,
|
||||
# Will allow to find the norm on the x axis
|
||||
xy + torch.tensor([pixel_hw_ndc[1], 0], device=xy.device),
|
||||
# Will allow to find the norm on the y axis
|
||||
xy + torch.tensor([0, pixel_hw_ndc[0]], device=xy.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
# [batch_size, 3 * n_rays_per_image, 3]
|
||||
xyz = torch.cat(
|
||||
(
|
||||
xy,
|
||||
xy.new_ones(batch_size, 3 * n_rays_per_image, 1),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# unproject the points
|
||||
unprojected_xyz = cameras.unproject_points(xyz, from_ndc=True)
|
||||
|
||||
plane_world, plane_world_dx, plane_world_dy = torch.split(
|
||||
unprojected_xyz, n_rays_per_image, dim=1
|
||||
)
|
||||
|
||||
# Distance from each unit-norm direction vector to its neighbors.
|
||||
dx_norm = torch.linalg.norm(plane_world_dx - plane_world, dim=-1, keepdims=True)
|
||||
dy_norm = torch.linalg.norm(plane_world_dy - plane_world, dim=-1, keepdims=True)
|
||||
# Cut the distance in half to obtain the base radius: (dx_norm + dy_norm) * 0.5
|
||||
# Scale it by 2/12**0.5 to match the variance of the pixel’s footprint
|
||||
radii = (dx_norm + dy_norm) / 12**0.5
|
||||
|
||||
return radii.view(batch_size, *spatial_size, 1)
|
||||
|
||||
@@ -177,6 +177,20 @@ def chunk_generator(
|
||||
|
||||
for start_idx in iter:
|
||||
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
|
||||
bins = (
|
||||
None
|
||||
if ray_bundle.bins is None
|
||||
else ray_bundle.bins.reshape(batch_size, n_rays, n_pts_per_ray + 1)[
|
||||
:, start_idx:end_idx
|
||||
]
|
||||
)
|
||||
pixel_radii_2d = (
|
||||
None
|
||||
if ray_bundle.pixel_radii_2d is None
|
||||
else ray_bundle.pixel_radii_2d.reshape(batch_size, -1, 1)[
|
||||
:, start_idx:end_idx
|
||||
]
|
||||
)
|
||||
ray_bundle_chunk = ImplicitronRayBundle(
|
||||
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
|
||||
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
|
||||
@@ -186,6 +200,8 @@ def chunk_generator(
|
||||
:, start_idx:end_idx
|
||||
],
|
||||
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
|
||||
bins=bins,
|
||||
pixel_radii_2d=pixel_radii_2d,
|
||||
camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
|
||||
camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
|
||||
)
|
||||
|
||||
@@ -58,6 +58,12 @@ class MultinomialRaysampler(torch.nn.Module):
|
||||
coordinate convention. For a raysampler which follows the PyTorch3D
|
||||
coordinate conventions please refer to `NDCMultinomialRaysampler`.
|
||||
As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`.
|
||||
|
||||
Attributes:
|
||||
min_x: The leftmost x-coordinate of each ray's source pixel's center.
|
||||
max_x: The rightmost x-coordinate of each ray's source pixel's center.
|
||||
min_y: The topmost y-coordinate of each ray's source pixel's center.
|
||||
max_y: The bottommost y-coordinate of each ray's source pixel's center.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -107,7 +113,8 @@ class MultinomialRaysampler(torch.nn.Module):
|
||||
self._n_rays_total = n_rays_total
|
||||
self._unit_directions = unit_directions
|
||||
self._stratified_sampling = stratified_sampling
|
||||
|
||||
self.min_x, self.max_x = min_x, max_x
|
||||
self.min_y, self.max_y = min_y, max_y
|
||||
# get the initial grid of image xy coords
|
||||
y, x = meshgrid_ij(
|
||||
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
|
||||
|
||||
Reference in New Issue
Block a user