diff --git a/tests/common_testing.py b/tests/common_testing.py index 6c84b2ec..f7a22413 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import os import unittest from pathlib import Path from typing import Callable, Optional, Union @@ -19,8 +20,13 @@ def get_tests_dir() -> Path: def get_pytorch3d_dir() -> Path: """ Returns Path for the root PyTorch3D directory. + + Facebook internal systems need a special case here. """ - return get_tests_dir().parent + if os.environ.get("INSIDE_RE_WORKER") is not None: + return Path(__file__).resolve().parent + else: + return Path(__file__).resolve().parent.parent def load_rgb_image(filename: str, data_dir: Union[str, Path]): diff --git a/tests/test_build.py b/tests/test_build.py index 9a3cd2be..e47cd292 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -10,10 +10,11 @@ from common_testing import get_pytorch3d_dir, get_tests_dir # This file groups together tests which look at the code without running it. # When running the tests inside conda's build, the code is not available. in_conda_build = os.environ.get("CONDA_BUILD_STATE", "") == "TEST" +in_re_worker = os.environ.get("INSIDE_RE_WORKER", "") is not None class TestBuild(unittest.TestCase): - @unittest.skipIf(in_conda_build, "In conda build") + @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker") def test_name_clash(self): # For setup.py, all translation units need distinct names, so we # cannot have foo.cu and foo.cpp, even in different directories. @@ -29,7 +30,7 @@ class TestBuild(unittest.TestCase): for k, v in counter.items(): self.assertEqual(v, 1, f"Too many files with stem {k}.") - @unittest.skipIf(in_conda_build, "In conda build") + @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker") def test_copyright(self): test_dir = get_tests_dir() root_dir = test_dir.parent @@ -61,7 +62,7 @@ class TestBuild(unittest.TestCase): if len(files_missing_copyright_header) != 0: self.fail("\n".join(files_missing_copyright_header)) - @unittest.skipIf(in_conda_build, "In conda build") + @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker") def test_valid_ipynbs(self): # Check that the ipython notebooks are valid json root_dir = get_pytorch3d_dir()