CUDA updates

Summary:
Updates to:
- enable CUDA kernel launches on any GPU (not just the default)
- add CUDA and contiguity checks for all kernels
- add checks to ensure all tensors are on the same device
- add error reporting in the CUDA kernels
- run the CUDA tests on a random device, not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
This commit is contained in:
Nikhila Ravi
2020-04-24 09:07:54 -07:00
committed by Facebook GitHub Bot
parent c9267ab7af
commit c3d636dc8c
33 changed files with 979 additions and 240 deletions

View File

@@ -6,7 +6,7 @@ from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures.pointclouds import Pointclouds
@@ -81,7 +81,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
)
@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2):
def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -97,7 +97,6 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
x_normals = p1.normals_padded()
y_normals = p2.normals_padded()
device = torch.device("cuda:0")
return_normals = x_normals is not None and y_normals is not None
# Initialize all distances to + inf
@@ -163,7 +162,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
"""
N, P1, D = x.shape
P2 = y.size(1)
device = torch.device("cuda:0")
device = x.device
return_normals = x_normals is not None and y_normals is not None
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
@@ -203,7 +202,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
This tests only uses homogeneous pointclouds.
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -237,7 +236,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
which supports heterogeneous pointcloud objects.
"""
N, max_P1, max_P2 = 3, 70, 70
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
@@ -256,7 +255,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2
points_normals.cloud1, points_normals.cloud2, device=device
)
# Mean reduction point loss.
@@ -299,7 +298,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
device = get_random_cuda_device()
reductions = [
("sum", "sum"),
@@ -359,7 +358,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_chamfer_pointcloud_object_nonormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
device = get_random_cuda_device()
reductions = [
("sum", "sum"),
@@ -415,7 +414,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for point_reduction = "mean" and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -464,7 +463,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for point_reduction = "sum" and batch_reduction = None.
"""
N, P1, P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -579,7 +578,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
point_reduction in ["mean", "sum"].
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
@@ -681,7 +680,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
device = torch.device("cuda:0")
device = get_random_cuda_device()
p1 = torch.rand(
(N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
)
@@ -716,7 +715,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_incorrect_inputs(self):
N, P1, P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -740,11 +739,16 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
@staticmethod
def chamfer_with_init(
batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
batch_size: int,
P1: int,
P2: int,
return_normals: bool,
homogeneous: bool,
device="cpu",
):
p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
l1 = points_normals.p1_lengths
l2 = points_normals.p2_lengths
if homogeneous:
# Set lengths to None so in Chamfer it assumes
# there is no padding.
@@ -754,13 +758,13 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def loss():
loss, loss_normals = chamfer_distance(
p1,
p2,
points_normals.p1,
points_normals.p2,
x_lengths=l1,
y_lengths=l2,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
weights=points_normals.weights,
)
torch.cuda.synchronize()
@@ -768,16 +772,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
@staticmethod
def chamfer_naive_with_init(
batch_size: int, P1: int, P2: int, return_normals: bool
batch_size: int, P1: int, P2: int, return_normals: bool, device="cpu"
):
p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
torch.cuda.synchronize()
def loss():
loss, loss_normals = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
)
torch.cuda.synchronize()