register_buffer compatibility

Summary: In D30349234 (1b8d86a104) we introduced persistent=False to some register_buffer calls, which depend on PyTorch 1.6. We go back to the old behaviour for PyTorch 1.5.

Reviewed By: nikhilaravi

Differential Revision: D30731327

fbshipit-source-id: ab02ef98ee87440ef02479b72f4872b562ab85b5
This commit is contained in:
Jeremy Reizenstein
2021-09-09 07:36:56 -07:00
committed by Facebook GitHub Bot
parent bbc7573261
commit c3d7808868
3 changed files with 14 additions and 3 deletions

View File

@@ -69,7 +69,12 @@ class HarmonicEmbedding(torch.nn.Module):
dtype=torch.float32,
)
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
try:
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
except TypeError:
# workaround for pytorch<1.6
self.register_buffer("_frequencies", omega0 * frequencies)
self.include_input = include_input
def forward(self, x: torch.Tensor) -> torch.Tensor: