From 1b8d86a104eab24ac25863c423d084d611f64bae Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 31 Aug 2021 14:29:11 -0700 Subject: [PATCH] (breaking) image_size-agnostic GridRaySampler Summary: As suggested in #802. By not persisting the _xy_grid buffer, we can allow (in some cases) a model with one image_size to be loaded from a saved model which was trained at a different resolution. Also avoid persisting _frequencies in HarmonicEmbedding for similar reasons. BC-break: This will cause load_state_dict, in strict mode, to complain if you try to load an old model with the new code. Reviewed By: patricklabatut Differential Revision: D30349234 fbshipit-source-id: d6061d1e51c9f79a78d61a9f732c9a5dfadbbb47 --- projects/nerf/nerf/harmonic_embedding.py | 6 +++--- pytorch3d/renderer/implicit/raysampling.py | 2 +- tests/test_raysampling.py | 20 ++++++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/projects/nerf/nerf/harmonic_embedding.py b/projects/nerf/nerf/harmonic_embedding.py index 2aab985e..4040eed9 100644 --- a/projects/nerf/nerf/harmonic_embedding.py +++ b/projects/nerf/nerf/harmonic_embedding.py @@ -14,7 +14,7 @@ class HarmonicEmbedding(torch.nn.Module): omega0: float = 1.0, logspace: bool = True, include_input: bool = True, - ): + ) -> None: """ Given an input tensor `x` of shape [minibatch, ... , dim], the harmonic embedding layer converts each feature @@ -69,10 +69,10 @@ class HarmonicEmbedding(torch.nn.Module): dtype=torch.float32, ) - self.register_buffer("_frequencies", omega0 * frequencies) + self.register_buffer("_frequencies", omega0 * frequencies, persistent=False) self.include_input = include_input - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: tensor of shape [..., dim] diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index a4711d6b..8cb25d76 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -96,7 +96,7 @@ class GridRaysampler(torch.nn.Module): ), dim=-1, ) - self.register_buffer("_xy_grid", _xy_grid) + self.register_buffer("_xy_grid", _xy_grid, persistent=False) def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: """ diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index 2a188443..766582d7 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -425,3 +425,23 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3), atol=1e-5, ) + + def test_load_state(self): + # check that we can load the state of one ray sampler into + # another with different image size. + module1 = NDCGridRaysampler( + image_width=20, + image_height=30, + n_pts_per_ray=40, + min_depth=1.2, + max_depth=2.3, + ) + module2 = NDCGridRaysampler( + image_width=22, + image_height=32, + n_pts_per_ray=42, + min_depth=1.2, + max_depth=2.3, + ) + state = module1.state_dict() + module2.load_state_dict(state)