Main training script

Summary: Implements the training script of NeRF.

Reviewed By: nikhilaravi

Differential Revision: D25684439

fbshipit-source-id: 8b19b6dc282eb6bf6e46ec4476bb0f13a84c90dd
This commit is contained in:
David Novotny
2021-02-02 05:42:59 -08:00
committed by Facebook GitHub Bot
parent 5b74911881
commit 9751f1f185
6 changed files with 466 additions and 1 deletions

View File

@@ -2,8 +2,11 @@
from typing import Tuple, List, Optional
import torch
from pytorch3d.renderer import ImplicitRenderer
from pytorch3d.renderer import ImplicitRenderer, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import plot_scene
from visdom import Visdom
from .implicit_function import NeuralRadianceField
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
@@ -357,3 +360,68 @@ class RadianceFieldRenderer(torch.nn.Module):
)
return out, metrics
def visualize_nerf_outputs(
nerf_out: dict, output_cache: List, viz: Visdom, visdom_env: str
):
"""
Visualizes the outputs of the `RadianceFieldRenderer`.
Args:
nerf_out: An output of the validation rendering pass.
output_cache: A list with outputs of several training render passes.
viz: A visdom connection object.
visdom_env: The name of visdom environment for visualization.
"""
# Show the training images.
ims = torch.stack([o["image"] for o in output_cache])
ims = torch.cat(list(ims), dim=1)
viz.image(
ims.permute(2, 0, 1),
env=visdom_env,
win="images",
opts={"title": "train_images"},
)
# Show the coarse and fine renders together with the ground truth images.
ims_full = torch.cat(
[
nerf_out[imvar][0].permute(2, 0, 1).detach().cpu().clamp(0.0, 1.0)
for imvar in ("rgb_coarse", "rgb_fine", "rgb_gt")
],
dim=2,
)
viz.image(
ims_full,
env=visdom_env,
win="images_full",
opts={"title": "coarse | fine | target"},
)
# Make a 3D plot of training cameras and their emitted rays.
camera_trace = {
f"camera_{ci:03d}": o["camera"].cpu() for ci, o in enumerate(output_cache)
}
ray_pts_trace = {
f"ray_pts_{ci:03d}": Pointclouds(
ray_bundle_to_ray_points(o["coarse_ray_bundle"])
.detach()
.cpu()
.view(1, -1, 3)
)
for ci, o in enumerate(output_cache)
}
plotly_plot = plot_scene(
{
"training_scene": {
**camera_trace,
**ray_pts_trace,
},
},
pointcloud_max_points=5000,
pointcloud_marker_size=1,
camera_scale=0.3,
)
viz.plotlyplot(plotly_plot, env=visdom_env, win="scenes")