clean IF args

Summary: continued - avoid duplicate inputs

Reviewed By: davnov134

Differential Revision: D38248827

fbshipit-source-id: 91ed398e304496a936f66e7a70ab3d189eeb5c70
This commit is contained in:
Jeremy Reizenstein
2022-08-03 12:37:31 -07:00
committed by Facebook GitHub Bot
parent 078846d166
commit 46e82efb4e
6 changed files with 65 additions and 37 deletions

View File

@@ -55,9 +55,10 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
def test_srn_hypernet_implicit_function(self):
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
latent_dim_hypernet = 39
hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
device = torch.device("cuda:0")
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
implicit_function = SRNHyperNetImplicitFunction(
latent_dim_hypernet=latent_dim_hypernet
)
implicit_function.to(device)
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)