mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
NeRF training: avoid caching unused visualization data.
Summary: If we are not visualizing the training with visdom, then there are a couple of outputs of the coarse rendering step which are not small and are returned by the renderer but never used. We don't need to bother transferring them to the CPU. Reviewed By: nikhilaravi Differential Revision: D28939958 fbshipit-source-id: 7e0d6681d6524f7fb57b6b20164580006120de80
This commit is contained in:
parent
7204a4ca64
commit
f00ef66727
@ -64,6 +64,7 @@ class RadianceFieldRenderer(torch.nn.Module):
|
||||
n_layers_xyz: int = 8,
|
||||
append_xyz: Tuple[int] = (5,),
|
||||
density_noise_std: float = 0.0,
|
||||
visualization: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -102,6 +103,7 @@ class RadianceFieldRenderer(torch.nn.Module):
|
||||
density_noise_std: The standard deviation of the random normal noise
|
||||
added to the output of the occupancy MLP.
|
||||
Active only when `self.training==True`.
|
||||
visualization: whether to store extra output for visualization.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
@ -159,6 +161,7 @@ class RadianceFieldRenderer(torch.nn.Module):
|
||||
self._density_noise_std = density_noise_std
|
||||
self._chunk_size_test = chunk_size_test
|
||||
self._image_size = image_size
|
||||
self.visualization = visualization
|
||||
|
||||
def precache_rays(
|
||||
self,
|
||||
@ -248,16 +251,15 @@ class RadianceFieldRenderer(torch.nn.Module):
|
||||
else:
|
||||
raise ValueError(f"No such rendering pass {renderer_pass}")
|
||||
|
||||
return {
|
||||
"rgb_fine": rgb_fine,
|
||||
"rgb_coarse": rgb_coarse,
|
||||
"rgb_gt": rgb_gt,
|
||||
out = {"rgb_fine": rgb_fine, "rgb_coarse": rgb_coarse, "rgb_gt": rgb_gt}
|
||||
if self.visualization:
|
||||
# Store the coarse rays/weights only for visualization purposes.
|
||||
"coarse_ray_bundle": type(coarse_ray_bundle)(
|
||||
out["coarse_ray_bundle"] = type(coarse_ray_bundle)(
|
||||
*[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()]
|
||||
),
|
||||
"coarse_weights": coarse_weights.detach().cpu(),
|
||||
}
|
||||
)
|
||||
out["coarse_weights"] = coarse_weights.detach().cpu()
|
||||
|
||||
return out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -52,6 +52,7 @@ def main(cfg: DictConfig):
|
||||
n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
|
||||
n_layers_xyz=cfg.implicit_function.n_layers_xyz,
|
||||
density_noise_std=cfg.implicit_function.density_noise_std,
|
||||
visualization=cfg.visualization.visdom,
|
||||
)
|
||||
|
||||
# Move the model to the relevant device.
|
||||
@ -195,17 +196,18 @@ def main(cfg: DictConfig):
|
||||
stats.print(stat_set="train")
|
||||
|
||||
# Update the visualization cache.
|
||||
visuals_cache.append(
|
||||
{
|
||||
"camera": camera.cpu(),
|
||||
"camera_idx": camera_idx,
|
||||
"image": image.cpu().detach(),
|
||||
"rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
|
||||
"rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
|
||||
"rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
|
||||
"coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
|
||||
}
|
||||
)
|
||||
if viz is not None:
|
||||
visuals_cache.append(
|
||||
{
|
||||
"camera": camera.cpu(),
|
||||
"camera_idx": camera_idx,
|
||||
"image": image.cpu().detach(),
|
||||
"rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
|
||||
"rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
|
||||
"rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
|
||||
"coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
|
||||
}
|
||||
)
|
||||
|
||||
# Adjust the learning rate.
|
||||
lr_scheduler.step()
|
||||
|
Loading…
x
Reference in New Issue
Block a user