diff --git a/tests/test_build.py b/tests/test_build.py index f74ca23c..de252d28 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -118,17 +118,19 @@ class TestBuild(unittest.TestCase): # Check each module of pytorch3d imports cleanly, # which may fail if there are import cycles. - # First check the setup of the test. If any of pytorch3d - # was already imported the test would be pointless. - for module in sys.modules: - self.assertFalse(module.startswith("pytorch3d"), module) + with unittest.mock.patch.dict(sys.modules): + for module in list(sys.modules): + # If any of pytorch3d is already imported, + # the test would be pointless. + if module.startswith("pytorch3d"): + sys.modules.pop(module, None) - root_dir = get_pytorch3d_dir() / "pytorch3d" - for module_file in root_dir.glob("**/*.py"): - if module_file.stem == "__init__": - continue - relative_module = str(module_file.relative_to(root_dir))[:-3] - module = "pytorch3d." + relative_module.replace("/", ".") - with self.subTest(name=module): - with unittest.mock.patch.dict(sys.modules): - importlib.import_module(module) + root_dir = get_pytorch3d_dir() / "pytorch3d" + for module_file in root_dir.glob("**/*.py"): + if module_file.stem in ("__init__", "plotly_vis"): + continue + relative_module = str(module_file.relative_to(root_dir))[:-3] + module = "pytorch3d." + relative_module.replace("/", ".") + with self.subTest(name=module): + with unittest.mock.patch.dict(sys.modules): + importlib.import_module(module)