From c3d7808868411177e8d6cd99ff3798ec169cdf9d Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 9 Sep 2021 07:36:56 -0700 Subject: [PATCH] register_buffer compatibility Summary: In D30349234 (https://github.com/facebookresearch/pytorch3d/commit/1b8d86a104eab24ac25863c423d084d611f64bae) we introduced persistent=False to some register_buffer calls, which depend on PyTorch 1.6. We go back to the old behaviour for PyTorch 1.5. Reviewed By: nikhilaravi Differential Revision: D30731327 fbshipit-source-id: ab02ef98ee87440ef02479b72f4872b562ab85b5 --- projects/nerf/nerf/harmonic_embedding.py | 7 ++++++- pytorch3d/renderer/implicit/raysampling.py | 5 ++++- tests/test_raysampling.py | 5 ++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/projects/nerf/nerf/harmonic_embedding.py b/projects/nerf/nerf/harmonic_embedding.py index 4040eed9..ecd11ce1 100644 --- a/projects/nerf/nerf/harmonic_embedding.py +++ b/projects/nerf/nerf/harmonic_embedding.py @@ -69,7 +69,12 @@ class HarmonicEmbedding(torch.nn.Module): dtype=torch.float32, ) - self.register_buffer("_frequencies", omega0 * frequencies, persistent=False) + try: + self.register_buffer("_frequencies", omega0 * frequencies, persistent=False) + except TypeError: + # workaround for pytorch<1.6 + self.register_buffer("_frequencies", omega0 * frequencies) + self.include_input = include_input def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index 8cb25d76..e3b33958 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -96,7 +96,10 @@ class GridRaysampler(torch.nn.Module): ), dim=-1, ) - self.register_buffer("_xy_grid", _xy_grid, persistent=False) + try: + self.register_buffer("_xy_grid", _xy_grid, persistent=False) + except TypeError: + self.register_buffer("_xy_grid", _xy_grid) # workaround for pytorch<1.6 def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: """ diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index 766582d7..9c7e3de3 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -426,7 +426,10 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): atol=1e-5, ) - def test_load_state(self): + @unittest.skipIf( + torch.__version__[:4] == "1.5.", "non persistent buffer needs PyTorch 1.6" + ) + def test_load_state_different_resolution(self): # check that we can load the state of one ray sampler into # another with different image size. module1 = NDCGridRaysampler(