mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
Main training script
Summary: Implements the training script of NeRF. Reviewed By: nikhilaravi Differential Revision: D25684439 fbshipit-source-id: 8b19b6dc282eb6bf6e46ec4476bb0f13a84c90dd
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5b74911881
commit
9751f1f185
@@ -2,8 +2,11 @@
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import ImplicitRenderer
|
||||
from pytorch3d.renderer import ImplicitRenderer, ray_bundle_to_ray_points
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.structures import Pointclouds
|
||||
from pytorch3d.vis.plotly_vis import plot_scene
|
||||
from visdom import Visdom
|
||||
|
||||
from .implicit_function import NeuralRadianceField
|
||||
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
|
||||
@@ -357,3 +360,68 @@ class RadianceFieldRenderer(torch.nn.Module):
|
||||
)
|
||||
|
||||
return out, metrics
|
||||
|
||||
|
||||
def visualize_nerf_outputs(
|
||||
nerf_out: dict, output_cache: List, viz: Visdom, visdom_env: str
|
||||
):
|
||||
"""
|
||||
Visualizes the outputs of the `RadianceFieldRenderer`.
|
||||
|
||||
Args:
|
||||
nerf_out: An output of the validation rendering pass.
|
||||
output_cache: A list with outputs of several training render passes.
|
||||
viz: A visdom connection object.
|
||||
visdom_env: The name of visdom environment for visualization.
|
||||
"""
|
||||
|
||||
# Show the training images.
|
||||
ims = torch.stack([o["image"] for o in output_cache])
|
||||
ims = torch.cat(list(ims), dim=1)
|
||||
viz.image(
|
||||
ims.permute(2, 0, 1),
|
||||
env=visdom_env,
|
||||
win="images",
|
||||
opts={"title": "train_images"},
|
||||
)
|
||||
|
||||
# Show the coarse and fine renders together with the ground truth images.
|
||||
ims_full = torch.cat(
|
||||
[
|
||||
nerf_out[imvar][0].permute(2, 0, 1).detach().cpu().clamp(0.0, 1.0)
|
||||
for imvar in ("rgb_coarse", "rgb_fine", "rgb_gt")
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
viz.image(
|
||||
ims_full,
|
||||
env=visdom_env,
|
||||
win="images_full",
|
||||
opts={"title": "coarse | fine | target"},
|
||||
)
|
||||
|
||||
# Make a 3D plot of training cameras and their emitted rays.
|
||||
camera_trace = {
|
||||
f"camera_{ci:03d}": o["camera"].cpu() for ci, o in enumerate(output_cache)
|
||||
}
|
||||
ray_pts_trace = {
|
||||
f"ray_pts_{ci:03d}": Pointclouds(
|
||||
ray_bundle_to_ray_points(o["coarse_ray_bundle"])
|
||||
.detach()
|
||||
.cpu()
|
||||
.view(1, -1, 3)
|
||||
)
|
||||
for ci, o in enumerate(output_cache)
|
||||
}
|
||||
plotly_plot = plot_scene(
|
||||
{
|
||||
"training_scene": {
|
||||
**camera_trace,
|
||||
**ray_pts_trace,
|
||||
},
|
||||
},
|
||||
pointcloud_max_points=5000,
|
||||
pointcloud_marker_size=1,
|
||||
camera_scale=0.3,
|
||||
)
|
||||
viz.plotlyplot(plotly_plot, env=visdom_env, win="scenes")
|
||||
|
||||
Reference in New Issue
Block a user