use no_grad for sample_pdf in NeRF project

Summary: We don't use gradents of sample_pdf. Here we disable gradient calculation around calling it, instead of calling detach later. There's a theoretical speedup, but mainly this enables using sample_pdf implementations which don't support gradients.

Reviewed By: nikhilaravi

Differential Revision: D28057284

fbshipit-source-id: 8a9d5e73f18b34e1e4291028008e02973023638d
This commit is contained in:
Jeremy Reizenstein 2021-04-28 09:33:35 -07:00 committed by Facebook GitHub Bot
parent 6053d0e46f
commit 097b0ef2c6

View File

@ -69,11 +69,11 @@ class ProbabilisticRaysampler(torch.nn.Module):
# Calculate the mid-points between the ray depths.
z_vals = input_ray_bundle.lengths
batch_size = z_vals.shape[0]
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
# Carry out the importance sampling.
z_samples = (
sample_pdf(
with torch.no_grad():
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self._n_pts_per_ray,
@ -81,10 +81,7 @@ class ProbabilisticRaysampler(torch.nn.Module):
(self._stratified and self.training)
or (self._stratified_test and not self.training)
),
)
.detach()
.view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
)
).view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
if self._add_input_samples:
# Add the new samples to the input ones.