diff --git a/projects/nerf/nerf/harmonic_embedding.py b/projects/nerf/nerf/harmonic_embedding.py index 4040eed9..ecd11ce1 100644 --- a/projects/nerf/nerf/harmonic_embedding.py +++ b/projects/nerf/nerf/harmonic_embedding.py @@ -69,7 +69,12 @@ class HarmonicEmbedding(torch.nn.Module): dtype=torch.float32, ) - self.register_buffer("_frequencies", omega0 * frequencies, persistent=False) + try: + self.register_buffer("_frequencies", omega0 * frequencies, persistent=False) + except TypeError: + # workaround for pytorch<1.6 + self.register_buffer("_frequencies", omega0 * frequencies) + self.include_input = include_input def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index 8cb25d76..e3b33958 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -96,7 +96,10 @@ class GridRaysampler(torch.nn.Module): ), dim=-1, ) - self.register_buffer("_xy_grid", _xy_grid, persistent=False) + try: + self.register_buffer("_xy_grid", _xy_grid, persistent=False) + except TypeError: + self.register_buffer("_xy_grid", _xy_grid) # workaround for pytorch<1.6 def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: """ diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index 766582d7..9c7e3de3 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -426,7 +426,10 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): atol=1e-5, ) - def test_load_state(self): + @unittest.skipIf( + torch.__version__[:4] == "1.5.", "non persistent buffer needs PyTorch 1.6" + ) + def test_load_state_different_resolution(self): # check that we can load the state of one ray sampler into # another with different image size. module1 = NDCGridRaysampler(