mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-02 18:26:01 +08:00
Adapt RayPointRefiner and RayMarcher to support bins.
Summary:
## Context
Bins are used in mipnerf to allow to manipulate easily intervals. For example, by doing the following, `bins[..., :-1]` you will obtain all the left coordinates of your intervals, while doing `bins[..., 1:]` is equals to the right coordinates of your intervals.
We introduce here the support of bins like in MipNerf implementation.
## RayPointRefiner
Small changes have been made to modify RayPointRefiner.
- If bins is None
```
mids = torch.lerp(ray_bundle.lengths[..., 1:], ray_bundle.lengths[…, :-1], 0.5)
z_samples = sample_pdf(
mids, # [..., npt]
weights[..., 1:-1], # [..., npt - 1]
….
)
```
- If bins is not None
In the MipNerf implementation the sampling is done on all the bins. It allows us to use the full weights tensor without slashing it.
```
z_samples = sample_pdf(
ray_bundle.bins, # [..., npt + 1]
weights, # [..., npt]
...
)
```
## RayMarcher
Add a ray_deltas optional argument. If None, keep the same deltas computation from ray_lengths.
Reviewed By: shapovalov
Differential Revision: D46389092
fbshipit-source-id: d4f1963310065bd31c1c7fac1adfe11cbeaba606
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5910d81b7b
commit
3d011a9198
@@ -78,19 +78,28 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
z_vals = input_ray_bundle.lengths
|
||||
with torch.no_grad():
|
||||
if self.blurpool_weights:
|
||||
ray_weights = apply_blurpool_on_weights(ray_weights)
|
||||
|
||||
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
|
||||
n_pts_per_ray = self.n_pts_per_ray
|
||||
ray_weights = ray_weights.view(-1, ray_weights.shape[-1])
|
||||
if input_ray_bundle.bins is None:
|
||||
z_vals: torch.Tensor = input_ray_bundle.lengths
|
||||
ray_weights = ray_weights[..., 1:-1]
|
||||
bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
|
||||
else:
|
||||
z_vals = input_ray_bundle.bins
|
||||
n_pts_per_ray += 1
|
||||
bins = z_vals
|
||||
z_samples = sample_pdf(
|
||||
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
|
||||
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
|
||||
self.n_pts_per_ray,
|
||||
bins.view(-1, bins.shape[-1]),
|
||||
ray_weights,
|
||||
n_pts_per_ray,
|
||||
det=not self.random_sampling,
|
||||
eps=self.sample_pdf_eps,
|
||||
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
|
||||
).view(*z_vals.shape[:-1], n_pts_per_ray)
|
||||
|
||||
if self.add_input_samples:
|
||||
z_vals = torch.cat((z_vals, z_samples), dim=-1)
|
||||
else:
|
||||
@@ -98,9 +107,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
# Resort by depth.
|
||||
z_vals, _ = torch.sort(z_vals, dim=-1)
|
||||
|
||||
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
|
||||
new_bundle.lengths = z_vals
|
||||
return new_bundle
|
||||
kwargs_ray = dict(vars(input_ray_bundle))
|
||||
if input_ray_bundle.bins is None:
|
||||
kwargs_ray["lengths"] = z_vals
|
||||
return ImplicitronRayBundle(**kwargs_ray)
|
||||
kwargs_ray["bins"] = z_vals
|
||||
del kwargs_ray["lengths"]
|
||||
return ImplicitronRayBundle.from_bins(**kwargs_ray)
|
||||
|
||||
|
||||
def apply_blurpool_on_weights(weights) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user