diff --git a/projects/nerf/__init__.py b/projects/nerf/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/projects/nerf/configs/fern.yaml b/projects/nerf/configs/fern.yaml
new file mode 100644
index 00000000..dd1db9ff
--- /dev/null
+++ b/projects/nerf/configs/fern.yaml
@@ -0,0 +1,44 @@
+seed: 3
+resume: True
+stats_print_interval: 10
+validation_epoch_interval: 150
+checkpoint_epoch_interval: 150
+checkpoint_path: 'checkpoints/fern_pt3d.pth'
+data:
+  dataset_name: 'fern'
+  image_size: [378, 504] # [height, width]
+  precache_rays: True
+test:
+  mode: 'evaluation'
+  trajectory_type: 'figure_eight'
+  up: [0.0, 1.0, 0.0]
+  scene_center: [0.0, 0.0, -2.0]
+  n_frames: 100
+  fps: 20
+optimizer:
+  max_epochs: 37500
+  lr: 0.0005
+  lr_scheduler_step_size: 12500
+  lr_scheduler_gamma: 0.1
+visualization:
+  history_size: 10
+  visdom: True
+  visdom_server: 'localhost'
+  visdom_port: 8097
+  visdom_env: 'nerf_pytorch3d'
+raysampler:
+  n_pts_per_ray: 64
+  n_pts_per_ray_fine: 64
+  n_rays_per_image: 1024
+  min_depth: 1.2
+  max_depth: 6.28
+  stratified: True
+  stratified_test: False
+  chunk_size_test: 6000
+implicit_function:
+  n_harmonic_functions_xyz: 10
+  n_harmonic_functions_dir: 4
+  n_hidden_neurons_xyz: 256
+  n_hidden_neurons_dir: 128
+  density_noise_std: 0.0
+  n_layers_xyz: 8
diff --git a/projects/nerf/configs/lego.yaml b/projects/nerf/configs/lego.yaml
new file mode 100644
index 00000000..f6483277
--- /dev/null
+++ b/projects/nerf/configs/lego.yaml
@@ -0,0 +1,44 @@
+seed: 3
+resume: True
+stats_print_interval: 10
+validation_epoch_interval: 30
+checkpoint_epoch_interval: 30
+checkpoint_path: 'checkpoints/lego_pt3d.pth'
+data:
+  dataset_name: 'lego'
+  image_size: [800, 800] # [height, width]
+  precache_rays: True
+test:
+  mode: 'evaluation'
+  trajectory_type: 'circular'
+  up: [0.0, 0.0, 1.0]
+  scene_center: [0.0, 0.0, 0.0]
+  n_frames: 100
+  fps: 20
+optimizer:
+  max_epochs: 20000
+  lr: 0.0005
+  lr_scheduler_step_size: 5000
+  lr_scheduler_gamma: 0.1
+visualization:
+  history_size: 10
+  visdom: True
+  visdom_server: 'localhost'
+  visdom_port: 8097
+  visdom_env: 'nerf_pytorch3d'
+raysampler:
+  n_pts_per_ray: 64
+  n_pts_per_ray_fine: 64
+  n_rays_per_image: 1024
+  min_depth: 2.0
+  max_depth: 6.0
+  stratified: True
+  stratified_test: False
+  chunk_size_test: 6000
+implicit_function:
+  n_harmonic_functions_xyz: 10
+  n_harmonic_functions_dir: 4
+  n_hidden_neurons_xyz: 256
+  n_hidden_neurons_dir: 128
+  density_noise_std: 0.0
+  n_layers_xyz: 8
diff --git a/projects/nerf/configs/pt3logo.yaml b/projects/nerf/configs/pt3logo.yaml
new file mode 100644
index 00000000..7c5ff708
--- /dev/null
+++ b/projects/nerf/configs/pt3logo.yaml
@@ -0,0 +1,44 @@
+seed: 3
+resume: True
+stats_print_interval: 10
+validation_epoch_interval: 30
+checkpoint_epoch_interval: 30
+checkpoint_path: 'checkpoints/pt3logo_pt3d.pth'
+data:
+  dataset_name: 'pt3logo'
+  image_size: [512, 1024] # [height, width]
+  precache_rays: True
+test:
+  mode: 'export_video'
+  trajectory_type: 'figure_eight'
+  up: [0.0, -1.0, 0.0]
+  scene_center: [0.0, 0.0, 0.0]
+  n_frames: 100
+  fps: 20
+optimizer:
+  max_epochs: 100000
+  lr: 0.0005
+  lr_scheduler_step_size: 10000
+  lr_scheduler_gamma: 0.1
+visualization:
+  history_size: 20
+  visdom: True
+  visdom_server: 'localhost'
+  visdom_port: 8097
+  visdom_env: 'nerf_pytorch3d'
+raysampler:
+  n_pts_per_ray: 64
+  n_pts_per_ray_fine: 64
+  n_rays_per_image: 1024
+  min_depth: 8.0
+  max_depth: 23.0
+  stratified: True
+  stratified_test: False
+  chunk_size_test: 6000
+implicit_function:
+  n_harmonic_functions_xyz: 10
+  n_harmonic_functions_dir: 4
+  n_hidden_neurons_xyz: 256
+  n_hidden_neurons_dir: 128
+  density_noise_std: 0.0
+  n_layers_xyz: 8
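[Annotation, not part of the patch] All three configs share one schema and are consumed through Hydra in train_nerf.py below. A minimal sketch of inspecting one of them directly with OmegaConf (the package the script already imports); the relative path is an assumption that the repository root is the working directory:

import os
from omegaconf import OmegaConf

# Load the config directly; Hydra resolves the same file via config_name="lego".
cfg = OmegaConf.load(os.path.join("projects", "nerf", "configs", "lego.yaml"))

# Nested keys are accessed with attribute syntax, exactly as train_nerf.py does.
print(cfg.data.dataset_name)            # 'lego'
print(cfg.raysampler.n_rays_per_image)  # 1024
print(cfg.optimizer.lr)                 # 0.0005

# Any key can also be overridden on the Hydra command line, e.g.:
#   python train_nerf.py optimizer.lr=0.001 data.precache_rays=False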
diff --git a/projects/nerf/nerf/nerf_renderer.py b/projects/nerf/nerf/nerf_renderer.py
index 57ec5db6..c01db773 100644
--- a/projects/nerf/nerf/nerf_renderer.py
+++ b/projects/nerf/nerf/nerf_renderer.py
@@ -2,8 +2,11 @@
 from typing import Tuple, List, Optional
 
 import torch
-from pytorch3d.renderer import ImplicitRenderer
+from pytorch3d.renderer import ImplicitRenderer, ray_bundle_to_ray_points
 from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Pointclouds
+from pytorch3d.vis.plotly_vis import plot_scene
+from visdom import Visdom
 
 from .implicit_function import NeuralRadianceField
 from .raymarcher import EmissionAbsorptionNeRFRaymarcher
@@ -357,3 +360,68 @@ class RadianceFieldRenderer(torch.nn.Module):
         )
 
         return out, metrics
+
+
+def visualize_nerf_outputs(
+    nerf_out: dict, output_cache: List, viz: Visdom, visdom_env: str
+):
+    """
+    Visualizes the outputs of the `RadianceFieldRenderer`.
+
+    Args:
+        nerf_out: An output of the validation rendering pass.
+        output_cache: A list with outputs of several training render passes.
+        viz: A visdom connection object.
+        visdom_env: The name of the visdom environment for visualization.
+    """
+
+    # Show the training images.
+    ims = torch.stack([o["image"] for o in output_cache])
+    ims = torch.cat(list(ims), dim=1)
+    viz.image(
+        ims.permute(2, 0, 1),
+        env=visdom_env,
+        win="images",
+        opts={"title": "train_images"},
+    )
+
+    # Show the coarse and fine renders together with the ground truth images.
+    ims_full = torch.cat(
+        [
+            nerf_out[imvar][0].permute(2, 0, 1).detach().cpu().clamp(0.0, 1.0)
+            for imvar in ("rgb_coarse", "rgb_fine", "rgb_gt")
+        ],
+        dim=2,
+    )
+    viz.image(
+        ims_full,
+        env=visdom_env,
+        win="images_full",
+        opts={"title": "coarse | fine | target"},
+    )
+
+    # Make a 3D plot of training cameras and their emitted rays.
+    camera_trace = {
+        f"camera_{ci:03d}": o["camera"].cpu() for ci, o in enumerate(output_cache)
+    }
+    ray_pts_trace = {
+        f"ray_pts_{ci:03d}": Pointclouds(
+            ray_bundle_to_ray_points(o["coarse_ray_bundle"])
+            .detach()
+            .cpu()
+            .view(1, -1, 3)
+        )
+        for ci, o in enumerate(output_cache)
+    }
+    plotly_plot = plot_scene(
+        {
+            "training_scene": {
+                **camera_trace,
+                **ray_pts_trace,
+            },
+        },
+        pointcloud_max_points=5000,
+        pointcloud_marker_size=1,
+        camera_scale=0.3,
+    )
+    viz.plotlyplot(plotly_plot, env=visdom_env, win="scenes")
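[Annotation, not part of the patch] The new visualize_nerf_outputs helper converts each cached ray bundle to world-space points with ray_bundle_to_ray_points before wrapping them in a Pointclouds object for plot_scene. A toy sketch of what that conversion computes, with made-up values:

import torch
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points

# One batch, two rays starting at the origin, with two depth samples each.
origins = torch.zeros(1, 2, 3)
directions = torch.tensor([[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
lengths = torch.tensor([[[1.0, 2.0], [1.0, 2.0]]])
bundle = RayBundle(origins, directions, lengths, xys=torch.zeros(1, 2, 2))

# Each point is origin + direction * length, one point per depth sample.
pts = ray_bundle_to_ray_points(bundle)
print(pts.shape)  # torch.Size([1, 2, 2, 3])
print(pts[0, 0])  # tensor([[0., 0., 1.], [0., 0., 2.]])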
diff --git a/projects/nerf/train_nerf.py b/projects/nerf/train_nerf.py
new file mode 100644
index 00000000..2822cb0f
--- /dev/null
+++ b/projects/nerf/train_nerf.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import collections
+import os
+import pickle
+import warnings
+
+import hydra
+import numpy as np
+import torch
+from nerf.dataset import get_nerf_datasets, trivial_collate
+from nerf.nerf_renderer import RadianceFieldRenderer, visualize_nerf_outputs
+from nerf.stats import Stats
+from omegaconf import DictConfig
+from visdom import Visdom
+
+CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
+
+
+@hydra.main(config_path=CONFIG_DIR, config_name="lego")
+def main(cfg: DictConfig):
+
+    # Set the relevant seeds for reproducibility.
+    np.random.seed(cfg.seed)
+    torch.manual_seed(cfg.seed)
+
+    # Device on which to run.
+    if torch.cuda.is_available():
+        device = "cuda"
+    else:
+        warnings.warn(
+            "Please note that although executing on CPU is supported, "
+            + "the training is unlikely to finish in reasonable time."
+        )
+        device = "cpu"
+
+    # Initialize the Radiance Field model.
+    model = RadianceFieldRenderer(
+        image_size=cfg.data.image_size,
+        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
+        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray_fine,
+        n_rays_per_image=cfg.raysampler.n_rays_per_image,
+        min_depth=cfg.raysampler.min_depth,
+        max_depth=cfg.raysampler.max_depth,
+        stratified=cfg.raysampler.stratified,
+        stratified_test=cfg.raysampler.stratified_test,
+        chunk_size_test=cfg.raysampler.chunk_size_test,
+        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
+        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
+        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
+        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
+        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
+        density_noise_std=cfg.implicit_function.density_noise_std,
+    )
+
+    # Move the model to the relevant device.
+    model.to(device)
+
+    # Init stats to None before loading.
+    stats = None
+    optimizer_state_dict = None
+    start_epoch = 0
+
+    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
+    if len(cfg.checkpoint_path) > 0:
+        # Make the root of the experiment directory.
+        checkpoint_dir = os.path.split(checkpoint_path)[0]
+        os.makedirs(checkpoint_dir, exist_ok=True)
+
+        # Resume training if requested.
+        if cfg.resume and os.path.isfile(checkpoint_path):
+            print(f"Resuming from checkpoint {checkpoint_path}.")
+            loaded_data = torch.load(checkpoint_path)
+            model.load_state_dict(loaded_data["model"])
+            stats = pickle.loads(loaded_data["stats"])
+            print(f" => resuming from epoch {stats.epoch}.")
+            optimizer_state_dict = loaded_data["optimizer"]
+            start_epoch = stats.epoch
+
+    # Initialize the optimizer.
+    optimizer = torch.optim.Adam(
+        model.parameters(),
+        lr=cfg.optimizer.lr,
+    )
+
+    # Load the optimizer state dict in case we are resuming.
+    if optimizer_state_dict is not None:
+        optimizer.load_state_dict(optimizer_state_dict)
+        optimizer.last_epoch = start_epoch
+
+    # Init the stats object.
+    if stats is None:
+        stats = Stats(
+            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
+        )
+
+    # Learning rate scheduler setup.
+
+    # Following the original code, we use exponential decay of the
+    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
+    def lr_lambda(epoch):
+        return cfg.optimizer.lr_scheduler_gamma ** (
+            epoch / cfg.optimizer.lr_scheduler_step_size
+        )
+
+    # The learning rate scheduling is implemented with the LambdaLR PyTorch scheduler.
+    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
+    )
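[Annotation, not part of the patch] A standalone sketch checking that the LambdaLR setup above reproduces current_lr = base_lr * gamma ** (epoch / step_size), using the lego.yaml optimizer values:

import torch

base_lr, gamma, step_size = 0.0005, 0.1, 5000  # the lego.yaml optimizer values

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=base_lr)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: gamma ** (epoch / step_size)
)

for _ in range(10000):
    optimizer.step()
    lr_scheduler.step()

# base_lr * gamma ** (10000 / 5000) = 0.0005 * 0.01 = 5e-06
print(optimizer.param_groups[0]["lr"])  # ~5e-06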
+    # Initialize the cache for storing variables needed for visualization.
+    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)
+
+    # Init the visualization visdom env.
+    if cfg.visualization.visdom:
+        viz = Visdom(
+            server=cfg.visualization.visdom_server,
+            port=cfg.visualization.visdom_port,
+            use_incoming_socket=False,
+        )
+    else:
+        viz = None
+
+    # Load the training/validation data.
+    train_dataset, val_dataset, _ = get_nerf_datasets(
+        dataset_name=cfg.data.dataset_name,
+        image_size=cfg.data.image_size,
+    )
+
+    if cfg.data.precache_rays:
+        # Precache the projection rays.
+        model.eval()
+        with torch.no_grad():
+            for dataset in (train_dataset, val_dataset):
+                cache_cameras = [e["camera"].to(device) for e in dataset]
+                cache_camera_hashes = [e["camera_idx"] for e in dataset]
+                model.precache_rays(cache_cameras, cache_camera_hashes)
+
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=1,
+        shuffle=True,
+        num_workers=0,
+        collate_fn=trivial_collate,
+    )
+
+    # The validation dataloader is just an endless stream of random samples.
+    val_dataloader = torch.utils.data.DataLoader(
+        val_dataset,
+        batch_size=1,
+        num_workers=0,
+        collate_fn=trivial_collate,
+        sampler=torch.utils.data.RandomSampler(
+            val_dataset,
+            replacement=True,
+            num_samples=cfg.optimizer.max_epochs,
+        ),
+    )
+
+    # Set the model to the training mode.
+    model.train()
+
+    # Run the main training loop.
+    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
+        stats.new_epoch()  # Init a new epoch.
+        for iteration, batch in enumerate(train_dataloader):
+            image, camera, camera_idx = batch[0].values()
+            image = image.to(device)
+            camera = camera.to(device)
+
+            optimizer.zero_grad()
+
+            # Run the forward pass of the model.
+            nerf_out, metrics = model(
+                camera_idx if cfg.data.precache_rays else None,
+                camera,
+                image,
+            )
+
+            # The loss is a sum of coarse and fine MSEs.
+            loss = metrics["mse_coarse"] + metrics["mse_fine"]
+
+            # Take the training step.
+            loss.backward()
+            optimizer.step()
+
+            # Update stats with the current metrics.
+            stats.update(
+                {"loss": float(loss), **metrics},
+                stat_set="train",
+            )
+
+            if iteration % cfg.stats_print_interval == 0:
+                stats.print(stat_set="train")
+
+            # Update the visualization cache.
+            visuals_cache.append(
+                {
+                    "camera": camera.cpu(),
+                    "camera_idx": camera_idx,
+                    "image": image.cpu().detach(),
+                    "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
+                    "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
+                    "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
+                    "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
+                }
+            )
+
+        # Adjust the learning rate.
+        lr_scheduler.step()
+
+        # Validation.
+        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
+
+            # Sample a validation camera/image.
+            val_batch = next(iter(val_dataloader))
+            val_image, val_camera, camera_idx = val_batch[0].values()
+            val_image = val_image.to(device)
+            val_camera = val_camera.to(device)
+
+            # Activate eval mode of the model (allows a full rendering pass).
+            model.eval()
+            with torch.no_grad():
+                val_nerf_out, val_metrics = model(
+                    camera_idx if cfg.data.precache_rays else None,
+                    val_camera,
+                    val_image,
+                )
+
+            # Update stats with the validation metrics.
+            stats.update(val_metrics, stat_set="val")
+            stats.print(stat_set="val")
+
+            if viz is not None:
+                # Plot the loss curves in visdom.
+                stats.plot_stats(
+                    viz=viz,
+                    visdom_env=cfg.visualization.visdom_env,
+                    plot_file=None,
+                )
+                # Visualize the intermediate results.
+                visualize_nerf_outputs(
+                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
+                )
+
+            # Set the model back to train mode.
+            model.train()
+
+        # Checkpoint.
+        if (
+            epoch % cfg.checkpoint_epoch_interval == 0
+            and len(cfg.checkpoint_path) > 0
+            and epoch > 0
+        ):
+            print(f"Storing checkpoint {checkpoint_path}.")
+            data_to_store = {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "stats": pickle.dumps(stats),
+            }
+            torch.save(data_to_store, checkpoint_path)
+
+
+if __name__ == "__main__":
+    main()
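[Annotation, not part of the patch] A checkpoint written by the loop above can be read back with the same three keys as data_to_store; a minimal sketch, assuming the lego checkpoint already exists at the configured path:

import pickle
import torch

loaded = torch.load("checkpoints/lego_pt3d.pth", map_location="cpu")
model_state = loaded["model"]          # RadianceFieldRenderer state dict
optimizer_state = loaded["optimizer"]  # Adam state, used when resuming
stats = pickle.loads(loaded["stats"])  # the pickled Stats object
print(f"checkpoint stored at epoch {stats.epoch}")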