Adapt RayPointRefiner and RayMarcher to support bins.

Summary:
## Context

Bins are used in mipnerf to allow to manipulate easily intervals. For example, by doing the following, `bins[..., :-1]` you will obtain all the left coordinates of your intervals, while doing `bins[..., 1:]` is equals to the right coordinates of your intervals.

We introduce here the support of bins like in MipNerf implementation.

## RayPointRefiner

Small changes have been made to modify RayPointRefiner.
- If bins is None

```
mids = torch.lerp(ray_bundle.lengths[..., 1:], ray_bundle.lengths[…, :-1], 0.5)
z_samples = sample_pdf(
		mids, # [..., npt]
		weights[..., 1:-1], # [..., npt - 1]
               ….
            )
```

- If bins is not None
In the MipNerf implementation the sampling is done on all the bins. It allows us to use the full weights tensor without slashing it.

```
z_samples = sample_pdf(
		ray_bundle.bins, # [..., npt + 1]
		weights, # [..., npt]
               ...
            )
```

## RayMarcher

Add a ray_deltas optional argument. If None, keep the same deltas computation from ray_lengths.

Reviewed By: shapovalov

Differential Revision: D46389092

fbshipit-source-id: d4f1963310065bd31c1c7fac1adfe11cbeaba606
This commit is contained in:
Emilien Garreau
2023-07-06 02:41:15 -07:00
committed by Facebook GitHub Bot
parent 5910d81b7b
commit 3d011a9198
5 changed files with 107 additions and 18 deletions

View File

@@ -78,19 +78,28 @@ class RayPointRefiner(Configurable, torch.nn.Module):
"""
z_vals = input_ray_bundle.lengths
with torch.no_grad():
if self.blurpool_weights:
ray_weights = apply_blurpool_on_weights(ray_weights)
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
n_pts_per_ray = self.n_pts_per_ray
ray_weights = ray_weights.view(-1, ray_weights.shape[-1])
if input_ray_bundle.bins is None:
z_vals: torch.Tensor = input_ray_bundle.lengths
ray_weights = ray_weights[..., 1:-1]
bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
else:
z_vals = input_ray_bundle.bins
n_pts_per_ray += 1
bins = z_vals
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self.n_pts_per_ray,
bins.view(-1, bins.shape[-1]),
ray_weights,
n_pts_per_ray,
det=not self.random_sampling,
eps=self.sample_pdf_eps,
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
).view(*z_vals.shape[:-1], n_pts_per_ray)
if self.add_input_samples:
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
@@ -98,9 +107,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle
kwargs_ray = dict(vars(input_ray_bundle))
if input_ray_bundle.bins is None:
kwargs_ray["lengths"] = z_vals
return ImplicitronRayBundle(**kwargs_ray)
kwargs_ray["bins"] = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray)
def apply_blurpool_on_weights(weights) -> torch.Tensor: