Adapt RayPointRefiner and RayMarcher to support bins.

Summary:
## Context

Bins are used in mipnerf to allow to manipulate easily intervals. For example, by doing the following, `bins[..., :-1]` you will obtain all the left coordinates of your intervals, while doing `bins[..., 1:]` is equals to the right coordinates of your intervals.

We introduce here the support of bins like in MipNerf implementation.

## RayPointRefiner

Small changes have been made to modify RayPointRefiner.
- If bins is None

```
mids = torch.lerp(ray_bundle.lengths[..., 1:], ray_bundle.lengths[…, :-1], 0.5)
z_samples = sample_pdf(
		mids, # [..., npt]
		weights[..., 1:-1], # [..., npt - 1]
               ….
            )
```

- If bins is not None
In the MipNerf implementation the sampling is done on all the bins. It allows us to use the full weights tensor without slashing it.

```
z_samples = sample_pdf(
		ray_bundle.bins, # [..., npt + 1]
		weights, # [..., npt]
               ...
            )
```

## RayMarcher

Add a ray_deltas optional argument. If None, keep the same deltas computation from ray_lengths.

Reviewed By: shapovalov

Differential Revision: D46389092

fbshipit-source-id: d4f1963310065bd31c1c7fac1adfe11cbeaba606
This commit is contained in:
Emilien Garreau 2023-07-06 02:41:15 -07:00 committed by Facebook GitHub Bot
parent 5910d81b7b
commit 3d011a9198
5 changed files with 107 additions and 18 deletions

View File

@ -157,9 +157,13 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
else 0.0
)
ray_deltas = (
None if ray_bundle.bins is None else torch.diff(ray_bundle.bins, dim=-1)
)
output = self.raymarcher(
*implicit_functions[0](ray_bundle=ray_bundle),
ray_lengths=ray_bundle.lengths,
ray_deltas=ray_deltas,
density_noise_std=density_noise_std,
)
output.prev_stage = prev_stage

View File

@ -78,19 +78,28 @@ class RayPointRefiner(Configurable, torch.nn.Module):
"""
z_vals = input_ray_bundle.lengths
with torch.no_grad():
if self.blurpool_weights:
ray_weights = apply_blurpool_on_weights(ray_weights)
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
n_pts_per_ray = self.n_pts_per_ray
ray_weights = ray_weights.view(-1, ray_weights.shape[-1])
if input_ray_bundle.bins is None:
z_vals: torch.Tensor = input_ray_bundle.lengths
ray_weights = ray_weights[..., 1:-1]
bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
else:
z_vals = input_ray_bundle.bins
n_pts_per_ray += 1
bins = z_vals
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self.n_pts_per_ray,
bins.view(-1, bins.shape[-1]),
ray_weights,
n_pts_per_ray,
det=not self.random_sampling,
eps=self.sample_pdf_eps,
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
).view(*z_vals.shape[:-1], n_pts_per_ray)
if self.add_input_samples:
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
@ -98,9 +107,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle
kwargs_ray = dict(vars(input_ray_bundle))
if input_ray_bundle.bins is None:
kwargs_ray["lengths"] = z_vals
return ImplicitronRayBundle(**kwargs_ray)
kwargs_ray["bins"] = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray)
def apply_blurpool_on_weights(weights) -> torch.Tensor:

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from pytorch3d.implicitron.models.renderer.base import RendererOutput
@ -119,6 +119,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
rays_features: torch.Tensor,
aux: Dict[str, Any],
ray_lengths: torch.Tensor,
ray_deltas: Optional[torch.Tensor] = None,
density_noise_std: float = 0.0,
**kwargs,
) -> RendererOutput:
@ -131,6 +132,9 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
aux: a dictionary with extra information.
ray_lengths: Per-ray depth values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`.
ray_deltas: Optional differences between consecutive elements along the ray bundle
represented with a tensor of shape `(..., n_points_per_ray)`. If None,
these differences are computed from ray_lengths.
density_noise_std: the magnitude of the noise added to densities.
Returns:
@ -152,7 +156,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
density_1d=True,
)
ray_lengths_diffs = ray_lengths[..., 1:] - ray_lengths[..., :-1]
if ray_deltas is None:
ray_lengths_diffs = torch.diff(ray_lengths, dim=-1)
if self.replicate_last_interval:
last_interval = ray_lengths_diffs[..., -1:]
else:
@ -160,6 +165,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
ray_lengths[..., :1], self.background_opacity
)
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
else:
deltas = ray_deltas
rays_densities = rays_densities[..., 0]

View File

@ -24,7 +24,7 @@ class HarmonicEmbedding(torch.nn.Module):
and the integrated position encoding in
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
During, the inference you can provide the extra argument `diag_cov`.
During the inference you can provide the extra argument `diag_cov`.
If `diag_cov is None`, it converts
rays parametrized with a `ray_bundle` to 3D points by

View File

@ -70,6 +70,71 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)
def test_simple_use_bins(self):
"""
Same spirit than test_simple but use bins in the ImplicitronRayBunle.
It has been duplicated to avoid cognitive overload while reading the
test (lot of if else).
"""
length = 15
n_pts_per_ray = 10
for add_input_samples, use_blurpool in product([False, True], [False, True]):
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
)
bundle = ImplicitronRayBundle(
lengths=None,
bins=torch.arange(length + 1, dtype=torch.float32).expand(
3, 25, length + 1
),
origins=None,
directions=None,
xys=None,
camera_ids=None,
camera_counts=None,
)
weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights, blurpool_weights=use_blurpool)
self.assertIsNone(refined.directions)
self.assertIsNone(refined.origins)
self.assertIsNone(refined.xys)
expected_bins = torch.linspace(0, length, n_pts_per_ray + 1)
expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1)
if add_input_samples:
expected_bins = torch.cat((bundle.bins, expected_bins), dim=-1).sort()[
0
]
full_expected = torch.lerp(
expected_bins[..., :-1], expected_bins[..., 1:], 0.5
)
self.assertClose(refined.lengths, full_expected)
ray_point_refiner_random = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
)
refined_random = ray_point_refiner_random(
bundle, weights, blurpool_weights=use_blurpool
)
lengths_random = refined_random.lengths
self.assertEqual(lengths_random.shape, full_expected.shape)
if not add_input_samples:
self.assertGreater(lengths_random.min().item(), 0)
self.assertLess(lengths_random.max().item(), length)
# Check sorted
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)
def test_apply_blurpool_on_weights(self):
weights = torch.tensor(
[