diff --git a/projects/nerf/nerf/nerf_renderer.py b/projects/nerf/nerf/nerf_renderer.py index 084cecc5..4880edfb 100644 --- a/projects/nerf/nerf/nerf_renderer.py +++ b/projects/nerf/nerf/nerf_renderer.py @@ -64,6 +64,7 @@ class RadianceFieldRenderer(torch.nn.Module): n_layers_xyz: int = 8, append_xyz: Tuple[int] = (5,), density_noise_std: float = 0.0, + visualization: bool = False, ): """ Args: @@ -102,6 +103,7 @@ class RadianceFieldRenderer(torch.nn.Module): density_noise_std: The standard deviation of the random normal noise added to the output of the occupancy MLP. Active only when `self.training==True`. + visualization: whether to store extra output for visualization. """ super().__init__() @@ -159,6 +161,7 @@ class RadianceFieldRenderer(torch.nn.Module): self._density_noise_std = density_noise_std self._chunk_size_test = chunk_size_test self._image_size = image_size + self.visualization = visualization def precache_rays( self, @@ -248,16 +251,15 @@ class RadianceFieldRenderer(torch.nn.Module): else: raise ValueError(f"No such rendering pass {renderer_pass}") - return { - "rgb_fine": rgb_fine, - "rgb_coarse": rgb_coarse, - "rgb_gt": rgb_gt, + out = {"rgb_fine": rgb_fine, "rgb_coarse": rgb_coarse, "rgb_gt": rgb_gt} + if self.visualization: # Store the coarse rays/weights only for visualization purposes. - "coarse_ray_bundle": type(coarse_ray_bundle)( + out["coarse_ray_bundle"] = type(coarse_ray_bundle)( *[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()] - ), - "coarse_weights": coarse_weights.detach().cpu(), - } + ) + out["coarse_weights"] = coarse_weights.detach().cpu() + + return out def forward( self, diff --git a/projects/nerf/train_nerf.py b/projects/nerf/train_nerf.py index 87847046..d028d9ca 100644 --- a/projects/nerf/train_nerf.py +++ b/projects/nerf/train_nerf.py @@ -52,6 +52,7 @@ def main(cfg: DictConfig): n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir, n_layers_xyz=cfg.implicit_function.n_layers_xyz, density_noise_std=cfg.implicit_function.density_noise_std, + visualization=cfg.visualization.visdom, ) # Move the model to the relevant device. @@ -195,17 +196,18 @@ def main(cfg: DictConfig): stats.print(stat_set="train") # Update the visualization cache. - visuals_cache.append( - { - "camera": camera.cpu(), - "camera_idx": camera_idx, - "image": image.cpu().detach(), - "rgb_fine": nerf_out["rgb_fine"].cpu().detach(), - "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(), - "rgb_gt": nerf_out["rgb_gt"].cpu().detach(), - "coarse_ray_bundle": nerf_out["coarse_ray_bundle"], - } - ) + if viz is not None: + visuals_cache.append( + { + "camera": camera.cpu(), + "camera_idx": camera_idx, + "image": image.cpu().detach(), + "rgb_fine": nerf_out["rgb_fine"].cpu().detach(), + "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(), + "rgb_gt": nerf_out["rgb_gt"].cpu().detach(), + "coarse_ray_bundle": nerf_out["coarse_ray_bundle"], + } + ) # Adjust the learning rate. lr_scheduler.step()