mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
(breaking) image_size-agnostic GridRaySampler
Summary: As suggested in #802. By not persisting the _xy_grid buffer, we can allow (in some cases) a model with one image_size to be loaded from a saved model which was trained at a different resolution. Also avoid persisting _frequencies in HarmonicEmbedding for similar reasons. BC-break: This will cause load_state_dict, in strict mode, to complain if you try to load an old model with the new code. Reviewed By: patricklabatut Differential Revision: D30349234 fbshipit-source-id: d6061d1e51c9f79a78d61a9f732c9a5dfadbbb47
This commit is contained in:
parent
1251446383
commit
1b8d86a104
@ -14,7 +14,7 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
omega0: float = 1.0,
|
omega0: float = 1.0,
|
||||||
logspace: bool = True,
|
logspace: bool = True,
|
||||||
include_input: bool = True,
|
include_input: bool = True,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
Given an input tensor `x` of shape [minibatch, ... , dim],
|
||||||
the harmonic embedding layer converts each feature
|
the harmonic embedding layer converts each feature
|
||||||
@ -69,10 +69,10 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.register_buffer("_frequencies", omega0 * frequencies)
|
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
|
||||||
self.include_input = include_input
|
self.include_input = include_input
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: tensor of shape [..., dim]
|
x: tensor of shape [..., dim]
|
||||||
|
@ -96,7 +96,7 @@ class GridRaysampler(torch.nn.Module):
|
|||||||
),
|
),
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
self.register_buffer("_xy_grid", _xy_grid)
|
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
|
||||||
|
|
||||||
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
||||||
"""
|
"""
|
||||||
|
@ -425,3 +425,23 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
|||||||
ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3),
|
ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3),
|
||||||
atol=1e-5,
|
atol=1e-5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_load_state(self):
|
||||||
|
# check that we can load the state of one ray sampler into
|
||||||
|
# another with different image size.
|
||||||
|
module1 = NDCGridRaysampler(
|
||||||
|
image_width=20,
|
||||||
|
image_height=30,
|
||||||
|
n_pts_per_ray=40,
|
||||||
|
min_depth=1.2,
|
||||||
|
max_depth=2.3,
|
||||||
|
)
|
||||||
|
module2 = NDCGridRaysampler(
|
||||||
|
image_width=22,
|
||||||
|
image_height=32,
|
||||||
|
n_pts_per_ray=42,
|
||||||
|
min_depth=1.2,
|
||||||
|
max_depth=2.3,
|
||||||
|
)
|
||||||
|
state = module1.state_dict()
|
||||||
|
module2.load_state_dict(state)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user