mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
register_buffer compatibility
Summary: In D30349234 (1b8d86a104
) we introduced persistent=False to some register_buffer calls, which depend on PyTorch 1.6. We go back to the old behaviour for PyTorch 1.5.
Reviewed By: nikhilaravi
Differential Revision: D30731327
fbshipit-source-id: ab02ef98ee87440ef02479b72f4872b562ab85b5
This commit is contained in:
parent
bbc7573261
commit
c3d7808868
@ -69,7 +69,12 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
|
try:
|
||||||
|
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
|
||||||
|
except TypeError:
|
||||||
|
# workaround for pytorch<1.6
|
||||||
|
self.register_buffer("_frequencies", omega0 * frequencies)
|
||||||
|
|
||||||
self.include_input = include_input
|
self.include_input = include_input
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
@ -96,7 +96,10 @@ class GridRaysampler(torch.nn.Module):
|
|||||||
),
|
),
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
|
try:
|
||||||
|
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
|
||||||
|
except TypeError:
|
||||||
|
self.register_buffer("_xy_grid", _xy_grid) # workaround for pytorch<1.6
|
||||||
|
|
||||||
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
||||||
"""
|
"""
|
||||||
|
@ -426,7 +426,10 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
|||||||
atol=1e-5,
|
atol=1e-5,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_load_state(self):
|
@unittest.skipIf(
|
||||||
|
torch.__version__[:4] == "1.5.", "non persistent buffer needs PyTorch 1.6"
|
||||||
|
)
|
||||||
|
def test_load_state_different_resolution(self):
|
||||||
# check that we can load the state of one ray sampler into
|
# check that we can load the state of one ray sampler into
|
||||||
# another with different image size.
|
# another with different image size.
|
||||||
module1 = NDCGridRaysampler(
|
module1 = NDCGridRaysampler(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user