ViewPooler class

Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator.

Reviewed By: shapovalov

Differential Revision: D35852367

fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
This commit is contained in:
David Novotny
2022-05-12 12:50:03 -07:00
committed by Facebook GitHub Bot
parent bef959c755
commit 47d06c8924
26 changed files with 304 additions and 110 deletions

View File

@@ -56,7 +56,12 @@ class TestGenericModel(unittest.TestCase):
cfg = _load_model_config_from_yaml(str(config_file))
model = GenericModel(**cfg)
model.to(device)
self._one_model_test(model, device, eval_test=True)
self._one_model_test(
model,
device,
eval_test=True,
bw_test=True,
)
def _one_model_test(
self,
@@ -64,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
device,
n_train_cameras: int = 5,
eval_test: bool = True,
bw_test: bool = True,
):
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
@@ -86,8 +92,12 @@ class TestGenericModel(unittest.TestCase):
**random_args,
evaluation_mode=EvaluationMode.TRAINING,
)
self.assertGreater(train_preds["objective"].item(), 0)
train_preds["objective"].backward()
self.assertTrue(
train_preds["objective"].isfinite().item()
) # check finiteness of the objective
if bw_test:
train_preds["objective"].backward()
if eval_test:
model.eval()