diff --git a/projects/nerf/nerf/harmonic_embedding.py b/projects/nerf/nerf/harmonic_embedding.py index 2aab985e..4040eed9 100644 --- a/projects/nerf/nerf/harmonic_embedding.py +++ b/projects/nerf/nerf/harmonic_embedding.py @@ -14,7 +14,7 @@ class HarmonicEmbedding(torch.nn.Module): omega0: float = 1.0, logspace: bool = True, include_input: bool = True, - ): + ) -> None: """ Given an input tensor `x` of shape [minibatch, ... , dim], the harmonic embedding layer converts each feature @@ -69,10 +69,10 @@ class HarmonicEmbedding(torch.nn.Module): dtype=torch.float32, ) - self.register_buffer("_frequencies", omega0 * frequencies) + self.register_buffer("_frequencies", omega0 * frequencies, persistent=False) self.include_input = include_input - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: tensor of shape [..., dim] diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index a4711d6b..8cb25d76 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -96,7 +96,7 @@ class GridRaysampler(torch.nn.Module): ), dim=-1, ) - self.register_buffer("_xy_grid", _xy_grid) + self.register_buffer("_xy_grid", _xy_grid, persistent=False) def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: """ diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index 2a188443..766582d7 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -425,3 +425,23 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3), atol=1e-5, ) + + def test_load_state(self): + # check that we can load the state of one ray sampler into + # another with different image size. + module1 = NDCGridRaysampler( + image_width=20, + image_height=30, + n_pts_per_ray=40, + min_depth=1.2, + max_depth=2.3, + ) + module2 = NDCGridRaysampler( + image_width=22, + image_height=32, + n_pts_per_ray=42, + min_depth=1.2, + max_depth=2.3, + ) + state = module1.state_dict() + module2.load_state_dict(state)