diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py index 6755f04a..8f456dd9 100644 --- a/tests/implicitron/test_forward_pass.py +++ b/tests/implicitron/test_forward_pass.py @@ -14,6 +14,10 @@ from pytorch3d.implicitron.models.renderer.base import EvaluationMode from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras +if os.environ.get("FB_TEST", False): + from .common_resources import provide_resnet34 +else: + from common_resources import provide_resnet34 if os.environ.get("FB_TEST", False): from common_testing import get_pytorch3d_dir @@ -26,6 +30,10 @@ IMPLICITRON_CONFIGS_DIR = ( class TestGenericModel(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + provide_resnet34() + def setUp(self): torch.manual_seed(42) @@ -54,6 +62,8 @@ class TestGenericModel(unittest.TestCase): for config_file in config_files: with self.subTest(name=config_file.stem): cfg = _load_model_config_from_yaml(str(config_file)) + cfg.render_image_height = 80 + cfg.render_image_width = 80 model = GenericModel(**cfg) model.to(device) self._one_model_test(