diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py index 8f456dd9..52d5016e 100644 --- a/tests/implicitron/test_forward_pass.py +++ b/tests/implicitron/test_forward_pass.py @@ -163,6 +163,7 @@ class TestGenericModel(unittest.TestCase): device = torch.device("cuda:1") args = get_default_args(GenericModel) args.view_pooler_enabled = True + args.image_feature_extractor_class_type = "ResNetFeatureExtractor" args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False model = GenericModel(**args) model.to(device)