diff --git a/tests/test_mesh_normal_consistency.py b/tests/test_mesh_normal_consistency.py index 3f2f1a93..23bd1c61 100644 --- a/tests/test_mesh_normal_consistency.py +++ b/tests/test_mesh_normal_consistency.py @@ -9,6 +9,17 @@ from pytorch3d.structures.meshes import Meshes from pytorch3d.utils.ico_sphere import ico_sphere +IS_TORCH_1_8 = torch.__version__.startswith("1.8.") +PROBLEMATIC_CUDA = torch.version.cuda in ("11.0", "11.1") +# TODO: There are problems with cuda 11.0 and 11.1 here. +# The symptom can be +# RuntimeError: radix_sort: failed on 1st step: cudaErrorInvalidDevice: invalid device ordinal +# or something like +# operator(): block: [0,0,0], thread: [96,0,0] +# Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed. +AVOID_LARGE_MESH_CUDA = PROBLEMATIC_CUDA and IS_TORCH_1_8 + + class TestMeshNormalConsistency(unittest.TestCase): def setUp(self) -> None: torch.manual_seed(42) @@ -37,7 +48,10 @@ class TestMeshNormalConsistency(unittest.TestCase): @staticmethod def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000): - device = torch.device("cuda:0") + if AVOID_LARGE_MESH_CUDA: + device = torch.device("cpu") + else: + device = torch.device("cuda:0") valid_faces = TestMeshNormalConsistency.init_faces(num_verts).to(device) verts_list = [] faces_list = []