mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
use no_grad for sample_pdf in NeRF project
Summary: We don't use gradents of sample_pdf. Here we disable gradient calculation around calling it, instead of calling detach later. There's a theoretical speedup, but mainly this enables using sample_pdf implementations which don't support gradients. Reviewed By: nikhilaravi Differential Revision: D28057284 fbshipit-source-id: 8a9d5e73f18b34e1e4291028008e02973023638d
This commit is contained in:
parent
6053d0e46f
commit
097b0ef2c6
@ -69,11 +69,11 @@ class ProbabilisticRaysampler(torch.nn.Module):
|
||||
# Calculate the mid-points between the ray depths.
|
||||
z_vals = input_ray_bundle.lengths
|
||||
batch_size = z_vals.shape[0]
|
||||
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
||||
|
||||
# Carry out the importance sampling.
|
||||
z_samples = (
|
||||
sample_pdf(
|
||||
with torch.no_grad():
|
||||
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
||||
z_samples = sample_pdf(
|
||||
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
|
||||
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
|
||||
self._n_pts_per_ray,
|
||||
@ -81,10 +81,7 @@ class ProbabilisticRaysampler(torch.nn.Module):
|
||||
(self._stratified and self.training)
|
||||
or (self._stratified_test and not self.training)
|
||||
),
|
||||
)
|
||||
.detach()
|
||||
.view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
|
||||
)
|
||||
).view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
|
||||
|
||||
if self._add_input_samples:
|
||||
# Add the new samples to the input ones.
|
||||
|
Loading…
x
Reference in New Issue
Block a user