CUDA updates

Summary:
Updates to:
- enable CUDA kernel launches on any GPU (not just the default)
- add CUDA and contiguity checks for all kernels
- add checks to ensure all tensors are on the same device
- add error reporting in the CUDA kernels
- run CUDA tests on a random device, not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
This commit is contained in:
Nikhila Ravi
2020-04-24 09:07:54 -07:00
committed by Facebook GitHub Bot
parent c9267ab7af
commit c3d636dc8c
33 changed files with 979 additions and 240 deletions

View File

@@ -4,7 +4,7 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance
from pytorch3d.structures import Meshes, Pointclouds, packed_to_list
@@ -203,7 +203,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
& PointEdgeArrayDistanceBackward
"""
P, E = 16, 32
device = torch.device("cuda:0")
device = get_random_cuda_device()
points = torch.rand((P, 3), dtype=torch.float32, device=device)
edges = torch.rand((E, 2, 3), dtype=torch.float32, device=device)
@@ -246,9 +246,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for PointEdgeDistanceForward
& PointEdgeDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -327,9 +327,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for EdgePointDistanceForward
& EdgePointDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -409,9 +409,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
"""
Test point_mesh_edge_distance from pytorch3d.loss
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# clone and detach for another backward pass through the op
verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -480,7 +480,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
& PointFaceArrayDistanceBackward
"""
P, T = 16, 32
device = torch.device("cuda:0")
device = get_random_cuda_device()
points = torch.rand((P, 3), dtype=torch.float32, device=device)
tris = torch.rand((T, 3, 3), dtype=torch.float32, device=device)
@@ -525,9 +525,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for PointFaceDistanceForward
& PointFaceDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -608,9 +608,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for FacePointDistanceForward
& FacePointDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -690,9 +690,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
"""
Test point_mesh_face_distance from pytorch3d.loss
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# clone and detach for another backward pass through the op
verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -751,7 +751,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
@staticmethod
def point_mesh_edge(N: int, V: int, F: int, P: int, device: str):
device = torch.device(device)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
N, V, F, P, device=device
)
torch.cuda.synchronize()
def loss():
@@ -763,7 +765,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
@staticmethod
def point_mesh_face(N: int, V: int, F: int, P: int, device: str):
device = torch.device(device)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
N, V, F, P, device=device
)
torch.cuda.synchronize()
def loss():