From 097b0ef2c640aa0b962495714f83ad9eb5e08fdf Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 28 Apr 2021 09:33:35 -0700 Subject: [PATCH] use no_grad for sample_pdf in NeRF project Summary: We don't use gradients of sample_pdf. Here we disable gradient calculation around calling it, instead of calling detach later. There's a theoretical speedup, but mainly this enables using sample_pdf implementations which don't support gradients. Reviewed By: nikhilaravi Differential Revision: D28057284 fbshipit-source-id: 8a9d5e73f18b34e1e4291028008e02973023638d --- projects/nerf/nerf/raysampler.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/projects/nerf/nerf/raysampler.py b/projects/nerf/nerf/raysampler.py index 6c49ba17..c46cdbe1 100644 --- a/projects/nerf/nerf/raysampler.py +++ b/projects/nerf/nerf/raysampler.py @@ -69,11 +69,11 @@ class ProbabilisticRaysampler(torch.nn.Module): # Calculate the mid-points between the ray depths. z_vals = input_ray_bundle.lengths batch_size = z_vals.shape[0] - z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) # Carry out the importance sampling. - z_samples = ( - sample_pdf( + with torch.no_grad(): + z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) + z_samples = sample_pdf( z_vals_mid.view(-1, z_vals_mid.shape[-1]), ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1], self._n_pts_per_ray, @@ -81,10 +81,7 @@ class ProbabilisticRaysampler(torch.nn.Module): (self._stratified and self.training) or (self._stratified_test and not self.training) ), - ) - .detach() - .view(batch_size, z_vals.shape[1], self._n_pts_per_ray) - ) + ).view(batch_size, z_vals.shape[1], self._n_pts_per_ray) if self._add_input_samples: # Add the new samples to the input ones.