mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Main training script
Summary: Implements the training script of NeRF. Reviewed By: nikhilaravi Differential Revision: D25684439 fbshipit-source-id: 8b19b6dc282eb6bf6e46ec4476bb0f13a84c90dd
This commit is contained in:
parent
5b74911881
commit
9751f1f185
0
projects/nerf/__init__.py
Normal file
0
projects/nerf/__init__.py
Normal file
44
projects/nerf/configs/fern.yaml
Normal file
44
projects/nerf/configs/fern.yaml
Normal file
@ -0,0 +1,44 @@
|
||||
# Training configuration for the 'fern' scene.

# Top-level run settings read by train_nerf.py.
seed: 3
resume: True
stats_print_interval: 10  # in training iterations within an epoch
validation_epoch_interval: 150  # in epochs; epoch 0 is skipped
checkpoint_epoch_interval: 150  # in epochs; epoch 0 is skipped
# Joined onto Hydra's original working directory by the training script.
checkpoint_path: 'checkpoints/fern_pt3d.pth'
data:
  dataset_name: 'fern'
  image_size: [378, 504] # [height, width]
  precache_rays: True
# Test-time settings; not read by the training script in this commit.
test:
  mode: 'evaluation'
  trajectory_type: 'figure_eight'
  up: [0.0, 1.0, 0.0]
  scene_center: [0.0, 0.0, -2.0]
  n_frames: 100
  fps: 20
optimizer:
  max_epochs: 37500
  lr: 0.0005
  # lr at a given epoch = lr * lr_scheduler_gamma ** (epoch / lr_scheduler_step_size)
  lr_scheduler_step_size: 12500
  lr_scheduler_gamma: 0.1
visualization:
  history_size: 10  # max number of cached training outputs to visualize
  visdom: True
  visdom_server: 'localhost'
  visdom_port: 8097
  visdom_env: 'nerf_pytorch3d'
# Forwarded to the RadianceFieldRenderer constructor.
raysampler:
  n_pts_per_ray: 64
  n_pts_per_ray_fine: 64
  n_rays_per_image: 1024
  min_depth: 1.2
  max_depth: 6.28
  stratified: True
  stratified_test: False
  chunk_size_test: 6000
# Forwarded to the RadianceFieldRenderer constructor.
implicit_function:
  n_harmonic_functions_xyz: 10
  n_harmonic_functions_dir: 4
  n_hidden_neurons_xyz: 256
  n_hidden_neurons_dir: 128
  density_noise_std: 0.0
  n_layers_xyz: 8
|
44
projects/nerf/configs/lego.yaml
Normal file
44
projects/nerf/configs/lego.yaml
Normal file
@ -0,0 +1,44 @@
|
||||
# Training configuration for the 'lego' scene (the default config of train_nerf.py).

# Top-level run settings read by train_nerf.py.
seed: 3
resume: True
stats_print_interval: 10  # in training iterations within an epoch
validation_epoch_interval: 30  # in epochs; epoch 0 is skipped
checkpoint_epoch_interval: 30  # in epochs; epoch 0 is skipped
# Joined onto Hydra's original working directory by the training script.
checkpoint_path: 'checkpoints/lego_pt3d.pth'
data:
  dataset_name: 'lego'
  image_size: [800, 800] # [height, width]
  precache_rays: True
# Test-time settings; not read by the training script in this commit.
test:
  mode: 'evaluation'
  trajectory_type: 'circular'
  up: [0.0, 0.0, 1.0]
  scene_center: [0.0, 0.0, 0.0]
  n_frames: 100
  fps: 20
optimizer:
  max_epochs: 20000
  lr: 0.0005
  # lr at a given epoch = lr * lr_scheduler_gamma ** (epoch / lr_scheduler_step_size)
  lr_scheduler_step_size: 5000
  lr_scheduler_gamma: 0.1
visualization:
  history_size: 10  # max number of cached training outputs to visualize
  visdom: True
  visdom_server: 'localhost'
  visdom_port: 8097
  visdom_env: 'nerf_pytorch3d'
# Forwarded to the RadianceFieldRenderer constructor.
raysampler:
  n_pts_per_ray: 64
  n_pts_per_ray_fine: 64
  n_rays_per_image: 1024
  min_depth: 2.0
  max_depth: 6.0
  stratified: True
  stratified_test: False
  chunk_size_test: 6000
# Forwarded to the RadianceFieldRenderer constructor.
implicit_function:
  n_harmonic_functions_xyz: 10
  n_harmonic_functions_dir: 4
  n_hidden_neurons_xyz: 256
  n_hidden_neurons_dir: 128
  density_noise_std: 0.0
  n_layers_xyz: 8
|
44
projects/nerf/configs/pt3logo.yaml
Normal file
44
projects/nerf/configs/pt3logo.yaml
Normal file
@ -0,0 +1,44 @@
|
||||
# Training configuration for the 'pt3logo' scene.

# Top-level run settings read by train_nerf.py.
seed: 3
resume: True
stats_print_interval: 10  # in training iterations within an epoch
validation_epoch_interval: 30  # in epochs; epoch 0 is skipped
checkpoint_epoch_interval: 30  # in epochs; epoch 0 is skipped
# Joined onto Hydra's original working directory by the training script.
checkpoint_path: 'checkpoints/pt3logo_pt3d.pth'
data:
  dataset_name: 'pt3logo'
  image_size: [512, 1024] # [height, width]
  precache_rays: True
# Test-time settings; not read by the training script in this commit.
test:
  mode: 'export_video'
  trajectory_type: 'figure_eight'
  up: [0.0, -1.0, 0.0]
  scene_center: [0.0, 0.0, 0.0]
  n_frames: 100
  fps: 20
optimizer:
  max_epochs: 100000
  lr: 0.0005
  # lr at a given epoch = lr * lr_scheduler_gamma ** (epoch / lr_scheduler_step_size)
  lr_scheduler_step_size: 10000
  lr_scheduler_gamma: 0.1
visualization:
  history_size: 20  # max number of cached training outputs to visualize
  visdom: True
  visdom_server: 'localhost'
  visdom_port: 8097
  visdom_env: 'nerf_pytorch3d'
# Forwarded to the RadianceFieldRenderer constructor.
raysampler:
  n_pts_per_ray: 64
  n_pts_per_ray_fine: 64
  n_rays_per_image: 1024
  min_depth: 8.0
  max_depth: 23.0
  stratified: True
  stratified_test: False
  chunk_size_test: 6000
# Forwarded to the RadianceFieldRenderer constructor.
implicit_function:
  n_harmonic_functions_xyz: 10
  n_harmonic_functions_dir: 4
  n_hidden_neurons_xyz: 256
  n_hidden_neurons_dir: 128
  density_noise_std: 0.0
  n_layers_xyz: 8
|
@ -2,8 +2,11 @@
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import ImplicitRenderer
|
||||
from pytorch3d.renderer import ImplicitRenderer, ray_bundle_to_ray_points
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.structures import Pointclouds
|
||||
from pytorch3d.vis.plotly_vis import plot_scene
|
||||
from visdom import Visdom
|
||||
|
||||
from .implicit_function import NeuralRadianceField
|
||||
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
|
||||
@ -357,3 +360,68 @@ class RadianceFieldRenderer(torch.nn.Module):
|
||||
)
|
||||
|
||||
return out, metrics
|
||||
|
||||
|
||||
def visualize_nerf_outputs(
    nerf_out: dict, output_cache: List, viz: Visdom, visdom_env: str
):
    """
    Pushes a summary of `RadianceFieldRenderer` outputs to Visdom.

    Three windows are written to the `visdom_env` environment:
        * "images": the cached training images concatenated side by side.
        * "images_full": the coarse render, the fine render and the
          ground truth of the validation pass, concatenated side by side.
        * "scenes": a plotly 3D scene with the cached training cameras
          and the points sampled along their coarse rays.

    Args:
        nerf_out: An output of the validation rendering pass. The entries
            "rgb_coarse", "rgb_fine" and "rgb_gt" are read.
        output_cache: A list with outputs of several training render passes.
            From each element the entries "image", "camera" and
            "coarse_ray_bundle" are read.
        viz: A visdom connection object.
        visdom_env: The name of visdom environment for visualization.
    """

    # --- Training images -------------------------------------------------
    # Stacking first enforces that all cached images share one shape; the
    # per-image slices are then joined along dim 1.
    stacked_train = torch.stack([entry["image"] for entry in output_cache])
    train_strip = torch.cat(stacked_train.unbind(0), dim=1)
    viz.image(
        train_strip.permute(2, 0, 1),
        env=visdom_env,
        win="images",
        opts={"title": "train_images"},
    )

    # --- Coarse / fine / ground truth of the validation pass -------------
    val_panels = []
    for key in ("rgb_coarse", "rgb_fine", "rgb_gt"):
        # Take the first batch element and move channels to the front.
        panel = nerf_out[key][0].permute(2, 0, 1)
        val_panels.append(panel.detach().cpu().clamp(0.0, 1.0))
    viz.image(
        torch.cat(val_panels, dim=2),
        env=visdom_env,
        win="images_full",
        opts={"title": "coarse | fine | target"},
    )

    # --- 3D plot of training cameras and their emitted rays --------------
    # All camera traces are inserted before all ray-point traces.
    scene_traces = {}
    for idx, entry in enumerate(output_cache):
        scene_traces[f"camera_{idx:03d}"] = entry["camera"].cpu()
    for idx, entry in enumerate(output_cache):
        ray_points = ray_bundle_to_ray_points(entry["coarse_ray_bundle"])
        flat_points = ray_points.detach().cpu().view(1, -1, 3)
        scene_traces[f"ray_pts_{idx:03d}"] = Pointclouds(flat_points)

    scene_figure = plot_scene(
        {"training_scene": scene_traces},
        pointcloud_max_points=5000,
        pointcloud_marker_size=1,
        camera_scale=0.3,
    )
    viz.plotlyplot(scene_figure, env=visdom_env, win="scenes")
|
||||
|
265
projects/nerf/train_nerf.py
Normal file
265
projects/nerf/train_nerf.py
Normal file
@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import collections
|
||||
import os
|
||||
import pickle
|
||||
import warnings
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from nerf.dataset import get_nerf_datasets, trivial_collate
|
||||
from nerf.nerf_renderer import RadianceFieldRenderer, visualize_nerf_outputs
|
||||
from nerf.stats import Stats
|
||||
from omegaconf import DictConfig
|
||||
from visdom import Visdom
|
||||
|
||||
# Absolute path of the "configs" directory that sits next to this script
# (symlinks resolved); passed to `hydra.main` as the config search path.
CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
|
||||
|
||||
|
||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
    """
    Trains a `RadianceFieldRenderer` as specified by the hydra config `cfg`.

    The function seeds the random generators, builds the model, optionally
    resumes model / optimizer / stats from `cfg.checkpoint_path`, and then
    runs the training loop with periodic validation, Visdom visualization
    and checkpointing.

    Args:
        cfg: The hydra config. The following entries are read:
            `seed`, `resume`, `stats_print_interval`,
            `validation_epoch_interval`, `checkpoint_epoch_interval`,
            `checkpoint_path` and the groups `data`, `optimizer`,
            `visualization`, `raysampler` and `implicit_function`.
    """

    # Set the relevant seeds for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, "
            "the training is unlikely to finish in reasonable time."
        )
        device = "cpu"

    # Initialize the Radiance Field model.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        # The fine pass has its own config entry.
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray_fine,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
    )

    # Move the model to the relevant device.
    model.to(device)

    # Init stats to None before loading.
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    # Hydra changes the working directory, so the checkpoint path is anchored
    # to the directory the script was originally launched from.
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    if len(cfg.checkpoint_path) > 0:
        # Make the root of the experiment directory.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Resume training if requested.
        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            # `map_location` allows resuming a checkpoint that was stored
            # from a different device (e.g. a GPU run resumed on CPU).
            loaded_data = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(loaded_data["model"])
            # NOTE: the stats are unpickled, which can execute arbitrary code;
            # only resume from checkpoints that come from a trusted source.
            stats = pickle.loads(loaded_data["stats"])
            print(f"   => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )

    # Load the optimizer state dict in case we are resuming.
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
        optimizer.last_epoch = start_epoch

    # Init the stats object.
    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    # `last_epoch=start_epoch - 1` makes the schedule continue where a resumed
    # run left off (-1 starts a fresh schedule when start_epoch == 0).
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1
    )

    # Initialize the cache for storing variables needed for visualization.
    # The deque keeps only the `history_size` most recent entries.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Precache the projection rays.
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # The validation dataloader is just an endless stream of random samples.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )

    # Set the model to the training mode.
    model.train()

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()  # Init a new epoch.
        for iteration, batch in enumerate(train_dataloader):
            # The batch holds a single dict (batch_size=1); unpacking relies
            # on the order of its values being (image, camera, camera_idx).
            image, camera, camera_idx = batch[0].values()
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

            # Run the forward pass of the model.
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # The loss is a sum of coarse and fine MSEs
            loss = metrics["mse_coarse"] + metrics["mse_fine"]

            # Take the training step.
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            # Update the visualization cache. The cache is only consumed by
            # the Visdom visualization, so skip the copies when it is off.
            if viz is not None:
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

        # Adjust the learning rate.
        lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:

            # Sample a validation camera/image.
            val_batch = next(iter(val_dataloader))
            val_image, val_camera, val_camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            # Activate eval mode of the model (allows to do a full rendering pass).
            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    val_camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            # Update stats with the validation metrics.
            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot the loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            # Set the model back to train mode.
            model.train()

        # Checkpoint.
        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and len(cfg.checkpoint_path) > 0
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)
|
||||
|
||||
|
||||
# Script entry point: `main` is wrapped by `hydra.main`, so it is called
# without arguments here.
if __name__ == "__main__":
    main()
|
Loading…
x
Reference in New Issue
Block a user