Make feature extractor pluggable

Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase.

Reviewed By: davnov134

Differential Revision: D35433098

fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a
This commit is contained in:
Jeremy Reizenstein
2022-05-18 08:50:18 -07:00
committed by Facebook GitHub Bot
parent cd7b885169
commit 9ec9d057cc
12 changed files with 151 additions and 57 deletions

View File

@@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase):
# Simple test of a forward and backward pass of the default GenericModel.
device = torch.device("cuda:1")
expand_args_fields(GenericModel)
model = GenericModel()
model = GenericModel(render_image_height=80, render_image_width=80)
model.to(device)
self._one_model_test(model, device)
@@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase):
)
self.assertGreater(train_preds["objective"].item(), 0)
def test_viewpool(self):
device = torch.device("cuda:1")
args = get_default_args(GenericModel)
args.view_pooler_enabled = True
args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
model = GenericModel(**args)
model.to(device)
n_train_cameras = 2
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
cameras = PerspectiveCameras(R=R, T=T, device=device)
defaulted_args = {
"fg_probability": None,
"depth_map": None,
"mask_crop": None,
}
target_image_rgb = torch.rand(
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
device=device,
)
train_preds = model(
camera=cameras,
evaluation_mode=EvaluationMode.TRAINING,
image_rgb=target_image_rgb,
sequence_name=["a"] * n_train_cameras,
**defaulted_args,
)
self.assertGreater(train_preds["objective"].item(), 0)
def _random_input_tensor(
N: int,