diff --git a/projects/nerf/nerf/raysampler.py b/projects/nerf/nerf/raysampler.py index a6b3c270..69e99b9a 100644 --- a/projects/nerf/nerf/raysampler.py +++ b/projects/nerf/nerf/raysampler.py @@ -330,9 +330,9 @@ class NeRFRaysampler(torch.nn.Module): if self.training: # During training we randomly subsample rays. - sel_rays = torch.randperm(n_pixels, device=device)[ - : self._mc_raysampler._n_rays_per_image - ] + sel_rays = torch.randperm( + n_pixels, device=full_ray_bundle.lengths.device + )[: self._mc_raysampler._n_rays_per_image] else: # In case we test, we take only the requested chunk. if chunksize is None: