mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
cuda 11 problems with test_normal_consistency
Summary: One test hits problems with CUDA 11.1 and pytorch 1.8. This seems to be a known bug, so we just run that test on the cpu in the problematic cases. Note - the full test run is much slower with cuda 11.1 than 10.2, but this is known. Reviewed By: patricklabatut Differential Revision: D28938933 fbshipit-source-id: cf8ed84cd10a0b52d8f4292edbef7bd4844fea65
This commit is contained in:
parent
c710d8c101
commit
7204a4ca64
@ -9,6 +9,17 @@ from pytorch3d.structures.meshes import Meshes
|
|||||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||||
|
|
||||||
|
|
||||||
|
IS_TORCH_1_8 = torch.__version__.startswith("1.8.")
|
||||||
|
PROBLEMATIC_CUDA = torch.version.cuda in ("11.0", "11.1")
|
||||||
|
# TODO: There are problems with cuda 11.0 and 11.1 here.
|
||||||
|
# The symptom can be
|
||||||
|
# RuntimeError: radix_sort: failed on 1st step: cudaErrorInvalidDevice: invalid device ordinal
|
||||||
|
# or something like
|
||||||
|
# operator(): block: [0,0,0], thread: [96,0,0]
|
||||||
|
# Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
|
||||||
|
AVOID_LARGE_MESH_CUDA = PROBLEMATIC_CUDA and IS_TORCH_1_8
|
||||||
|
|
||||||
|
|
||||||
class TestMeshNormalConsistency(unittest.TestCase):
|
class TestMeshNormalConsistency(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
@ -37,6 +48,9 @@ class TestMeshNormalConsistency(unittest.TestCase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000):
|
def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000):
|
||||||
|
if AVOID_LARGE_MESH_CUDA:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
else:
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
valid_faces = TestMeshNormalConsistency.init_faces(num_verts).to(device)
|
valid_faces = TestMeshNormalConsistency.init_faces(num_verts).to(device)
|
||||||
verts_list = []
|
verts_list = []
|
||||||
|
Loading…
x
Reference in New Issue
Block a user