Adapt RayPointRefiner and RayMarcher to support bins.

Summary:
## Context

Bins are used in mipnerf to allow to manipulate easily intervals. For example, by doing the following, `bins[..., :-1]` you will obtain all the left coordinates of your intervals, while doing `bins[..., 1:]` is equals to the right coordinates of your intervals.

We introduce here the support of bins like in MipNerf implementation.

## RayPointRefiner

Small changes have been made to modify RayPointRefiner.
- If bins is None

```
mids = torch.lerp(ray_bundle.lengths[..., 1:], ray_bundle.lengths[…, :-1], 0.5)
z_samples = sample_pdf(
		mids, # [..., npt]
		weights[..., 1:-1], # [..., npt - 1]
               ….
            )
```

- If bins is not None
In the MipNerf implementation the sampling is done on all the bins. It allows us to use the full weights tensor without slashing it.

```
z_samples = sample_pdf(
		ray_bundle.bins, # [..., npt + 1]
		weights, # [..., npt]
               ...
            )
```

## RayMarcher

Add a ray_deltas optional argument. If None, keep the same deltas computation from ray_lengths.

Reviewed By: shapovalov

Differential Revision: D46389092

fbshipit-source-id: d4f1963310065bd31c1c7fac1adfe11cbeaba606
This commit is contained in:
Emilien Garreau 2023-07-06 02:41:15 -07:00 committed by Facebook GitHub Bot
parent 5910d81b7b
commit 3d011a9198
5 changed files with 107 additions and 18 deletions

View File

@ -157,9 +157,13 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
else 0.0 else 0.0
) )
ray_deltas = (
None if ray_bundle.bins is None else torch.diff(ray_bundle.bins, dim=-1)
)
output = self.raymarcher( output = self.raymarcher(
*implicit_functions[0](ray_bundle=ray_bundle), *implicit_functions[0](ray_bundle=ray_bundle),
ray_lengths=ray_bundle.lengths, ray_lengths=ray_bundle.lengths,
ray_deltas=ray_deltas,
density_noise_std=density_noise_std, density_noise_std=density_noise_std,
) )
output.prev_stage = prev_stage output.prev_stage = prev_stage

View File

@ -78,19 +78,28 @@ class RayPointRefiner(Configurable, torch.nn.Module):
""" """
z_vals = input_ray_bundle.lengths
with torch.no_grad(): with torch.no_grad():
if self.blurpool_weights: if self.blurpool_weights:
ray_weights = apply_blurpool_on_weights(ray_weights) ray_weights = apply_blurpool_on_weights(ray_weights)
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5) n_pts_per_ray = self.n_pts_per_ray
ray_weights = ray_weights.view(-1, ray_weights.shape[-1])
if input_ray_bundle.bins is None:
z_vals: torch.Tensor = input_ray_bundle.lengths
ray_weights = ray_weights[..., 1:-1]
bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
else:
z_vals = input_ray_bundle.bins
n_pts_per_ray += 1
bins = z_vals
z_samples = sample_pdf( z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]), bins.view(-1, bins.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1], ray_weights,
self.n_pts_per_ray, n_pts_per_ray,
det=not self.random_sampling, det=not self.random_sampling,
eps=self.sample_pdf_eps, eps=self.sample_pdf_eps,
).view(*z_vals.shape[:-1], self.n_pts_per_ray) ).view(*z_vals.shape[:-1], n_pts_per_ray)
if self.add_input_samples: if self.add_input_samples:
z_vals = torch.cat((z_vals, z_samples), dim=-1) z_vals = torch.cat((z_vals, z_samples), dim=-1)
else: else:
@ -98,9 +107,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
# Resort by depth. # Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1) z_vals, _ = torch.sort(z_vals, dim=-1)
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle)) kwargs_ray = dict(vars(input_ray_bundle))
new_bundle.lengths = z_vals if input_ray_bundle.bins is None:
return new_bundle kwargs_ray["lengths"] = z_vals
return ImplicitronRayBundle(**kwargs_ray)
kwargs_ray["bins"] = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray)
def apply_blurpool_on_weights(weights) -> torch.Tensor: def apply_blurpool_on_weights(weights) -> torch.Tensor:

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Tuple from typing import Any, Callable, Dict, Optional, Tuple
import torch import torch
from pytorch3d.implicitron.models.renderer.base import RendererOutput from pytorch3d.implicitron.models.renderer.base import RendererOutput
@ -119,6 +119,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
rays_features: torch.Tensor, rays_features: torch.Tensor,
aux: Dict[str, Any], aux: Dict[str, Any],
ray_lengths: torch.Tensor, ray_lengths: torch.Tensor,
ray_deltas: Optional[torch.Tensor] = None,
density_noise_std: float = 0.0, density_noise_std: float = 0.0,
**kwargs, **kwargs,
) -> RendererOutput: ) -> RendererOutput:
@ -131,6 +132,9 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
aux: a dictionary with extra information. aux: a dictionary with extra information.
ray_lengths: Per-ray depth values represented with a tensor ray_lengths: Per-ray depth values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`. of shape `(..., n_points_per_ray, feature_dim)`.
ray_deltas: Optional differences between consecutive elements along the ray bundle
represented with a tensor of shape `(..., n_points_per_ray)`. If None,
these differences are computed from ray_lengths.
density_noise_std: the magnitude of the noise added to densities. density_noise_std: the magnitude of the noise added to densities.
Returns: Returns:
@ -152,14 +156,17 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
density_1d=True, density_1d=True,
) )
ray_lengths_diffs = ray_lengths[..., 1:] - ray_lengths[..., :-1] if ray_deltas is None:
if self.replicate_last_interval: ray_lengths_diffs = torch.diff(ray_lengths, dim=-1)
last_interval = ray_lengths_diffs[..., -1:] if self.replicate_last_interval:
last_interval = ray_lengths_diffs[..., -1:]
else:
last_interval = torch.full_like(
ray_lengths[..., :1], self.background_opacity
)
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
else: else:
last_interval = torch.full_like( deltas = ray_deltas
ray_lengths[..., :1], self.background_opacity
)
deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1)
rays_densities = rays_densities[..., 0] rays_densities = rays_densities[..., 0]

View File

@ -24,7 +24,7 @@ class HarmonicEmbedding(torch.nn.Module):
and the integrated position encoding in and the integrated position encoding in
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
During, the inference you can provide the extra argument `diag_cov`. During the inference you can provide the extra argument `diag_cov`.
If `diag_cov is None`, it converts If `diag_cov is None`, it converts
rays parametrized with a `ray_bundle` to 3D points by rays parametrized with a `ray_bundle` to 3D points by

View File

@ -70,6 +70,71 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all() (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
) )
def test_simple_use_bins(self):
"""
Same spirit than test_simple but use bins in the ImplicitronRayBunle.
It has been duplicated to avoid cognitive overload while reading the
test (lot of if else).
"""
length = 15
n_pts_per_ray = 10
for add_input_samples, use_blurpool in product([False, True], [False, True]):
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
)
bundle = ImplicitronRayBundle(
lengths=None,
bins=torch.arange(length + 1, dtype=torch.float32).expand(
3, 25, length + 1
),
origins=None,
directions=None,
xys=None,
camera_ids=None,
camera_counts=None,
)
weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights, blurpool_weights=use_blurpool)
self.assertIsNone(refined.directions)
self.assertIsNone(refined.origins)
self.assertIsNone(refined.xys)
expected_bins = torch.linspace(0, length, n_pts_per_ray + 1)
expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1)
if add_input_samples:
expected_bins = torch.cat((bundle.bins, expected_bins), dim=-1).sort()[
0
]
full_expected = torch.lerp(
expected_bins[..., :-1], expected_bins[..., 1:], 0.5
)
self.assertClose(refined.lengths, full_expected)
ray_point_refiner_random = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
)
refined_random = ray_point_refiner_random(
bundle, weights, blurpool_weights=use_blurpool
)
lengths_random = refined_random.lengths
self.assertEqual(lengths_random.shape, full_expected.shape)
if not add_input_samples:
self.assertGreater(lengths_random.min().item(), 0)
self.assertLess(lengths_random.max().item(), length)
# Check sorted
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)
def test_apply_blurpool_on_weights(self): def test_apply_blurpool_on_weights(self):
weights = torch.tensor( weights = torch.tensor(
[ [