diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index c20e5e6e..a4711d6b 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -123,7 +123,7 @@ class GridRaysampler(torch.nn.Module): device = cameras.device # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2) - xy_grid = self._xy_grid.to(device)[None].expand( # pyre-ignore + xy_grid = self._xy_grid.to(device)[None].expand( batch_size, *self._xy_grid.shape )