diff --git a/projects/nerf/nerf/raysampler.py b/projects/nerf/nerf/raysampler.py index 6c49ba17..c46cdbe1 100644 --- a/projects/nerf/nerf/raysampler.py +++ b/projects/nerf/nerf/raysampler.py @@ -69,11 +69,11 @@ class ProbabilisticRaysampler(torch.nn.Module): # Calculate the mid-points between the ray depths. z_vals = input_ray_bundle.lengths batch_size = z_vals.shape[0] - z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) # Carry out the importance sampling. - z_samples = ( - sample_pdf( + with torch.no_grad(): + z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) + z_samples = sample_pdf( z_vals_mid.view(-1, z_vals_mid.shape[-1]), ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1], self._n_pts_per_ray, @@ -81,10 +81,7 @@ class ProbabilisticRaysampler(torch.nn.Module): (self._stratified and self.training) or (self._stratified_test and not self.training) ), - ) - .detach() - .view(batch_size, z_vals.shape[1], self._n_pts_per_ray) - ) + ).view(batch_size, z_vals.shape[1], self._n_pts_per_ray) if self._add_input_samples: # Add the new samples to the input ones.