Cuda updates

Summary:
Updates to:
- enable cuda kernel launches on any GPU (not just the default)
- cuda and contiguous checks for all kernels
- checks to ensure all tensors are on the same device
- error reporting in the cuda kernels
- cuda tests now run on a random device not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
This commit is contained in:
Nikhila Ravi
2020-04-24 09:07:54 -07:00
committed by Facebook GitHub Bot
parent c9267ab7af
commit c3d636dc8c
33 changed files with 979 additions and 240 deletions

View File

@@ -8,11 +8,21 @@ from test_chamfer import TestChamfer
def bm_chamfer() -> None:
kwargs_list_naive = [
{"batch_size": 1, "P1": 32, "P2": 64, "return_normals": False},
{"batch_size": 1, "P1": 32, "P2": 64, "return_normals": True},
{"batch_size": 32, "P1": 32, "P2": 64, "return_normals": False},
]
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda:0")
kwargs_list_naive = []
batch_size = [1, 32]
return_normals = [True, False]
test_cases = product(batch_size, return_normals, devices)
for case in test_cases:
b, n, d = case
kwargs_list_naive.append(
{"batch_size": b, "P1": 32, "P2": 64, "return_normals": n, "device": d}
)
benchmark(
TestChamfer.chamfer_naive_with_init,
"CHAMFER_NAIVE",
@@ -21,6 +31,7 @@ def bm_chamfer() -> None:
)
if torch.cuda.is_available():
device = "cuda:0"
kwargs_list = []
batch_size = [1, 32]
P1 = [32, 1000, 10000]
@@ -38,6 +49,7 @@ def bm_chamfer() -> None:
"P2": p2,
"return_normals": n,
"homogeneous": h,
"device": device,
}
)
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)

View File

@@ -20,6 +20,18 @@ def load_rgb_image(filename: str, data_dir: Union[str, Path]):
TensorOrArray = Union[torch.Tensor, np.ndarray]
def get_random_cuda_device() -> str:
"""
Function to get a random GPU device from the
available devices. This is useful for testing
that custom cuda kernels can support inputs on
any device without having to set the device explicitly.
"""
num_devices = torch.cuda.device_count()
rand_device_id = torch.randint(high=num_devices, size=(1,)).item()
return "cuda:%d" % rand_device_id
class TestCaseMixin(unittest.TestCase):
def assertSeparate(self, tensor1, tensor2) -> None:
"""

View File

@@ -6,7 +6,7 @@ from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures.pointclouds import Pointclouds
@@ -81,7 +81,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
)
@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2):
def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -97,7 +97,6 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
x_normals = p1.normals_padded()
y_normals = p2.normals_padded()
device = torch.device("cuda:0")
return_normals = x_normals is not None and y_normals is not None
# Initialize all distances to + inf
@@ -163,7 +162,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
"""
N, P1, D = x.shape
P2 = y.size(1)
device = torch.device("cuda:0")
device = x.device
return_normals = x_normals is not None and y_normals is not None
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
@@ -203,7 +202,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
This tests only uses homogeneous pointclouds.
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -237,7 +236,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
which supports heterogeneous pointcloud objects.
"""
N, max_P1, max_P2 = 3, 70, 70
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
@@ -256,7 +255,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2
points_normals.cloud1, points_normals.cloud2, device=device
)
# Mean reduction point loss.
@@ -299,7 +298,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
device = get_random_cuda_device()
reductions = [
("sum", "sum"),
@@ -359,7 +358,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_chamfer_pointcloud_object_nonormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
device = get_random_cuda_device()
reductions = [
("sum", "sum"),
@@ -415,7 +414,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for point_reduction = "mean" and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -464,7 +463,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for point_reduction = "sum" and batch_reduction = None.
"""
N, P1, P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -579,7 +578,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
point_reduction in ["mean", "sum"].
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
@@ -681,7 +680,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
device = torch.device("cuda:0")
device = get_random_cuda_device()
p1 = torch.rand(
(N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
)
@@ -716,7 +715,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_incorrect_inputs(self):
N, P1, P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@@ -740,11 +739,16 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
@staticmethod
def chamfer_with_init(
batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
batch_size: int,
P1: int,
P2: int,
return_normals: bool,
homogeneous: bool,
device="cpu",
):
p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
l1 = points_normals.p1_lengths
l2 = points_normals.p2_lengths
if homogeneous:
# Set lengths to None so in Chamfer it assumes
# there is no padding.
@@ -754,13 +758,13 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def loss():
loss, loss_normals = chamfer_distance(
p1,
p2,
points_normals.p1,
points_normals.p2,
x_lengths=l1,
y_lengths=l2,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
weights=points_normals.weights,
)
torch.cuda.synchronize()
@@ -768,16 +772,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
@staticmethod
def chamfer_naive_with_init(
batch_size: int, P1: int, P2: int, return_normals: bool
batch_size: int, P1: int, P2: int, return_normals: bool, device="cpu"
):
p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
torch.cuda.synchronize()
def loss():
loss, loss_normals = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
)
torch.cuda.synchronize()

View File

@@ -3,6 +3,7 @@
import unittest
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.renderer.compositing import (
alpha_composite,
norm_weighted_sum,
@@ -10,7 +11,7 @@ from pytorch3d.renderer.compositing import (
)
class TestAccumulatePoints(unittest.TestCase):
class TestAccumulatePoints(TestCaseMixin, unittest.TestCase):
# NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
@staticmethod
@@ -120,7 +121,7 @@ class TestAccumulatePoints(unittest.TestCase):
self._simple_wsumnorm(norm_weighted_sum, device)
def test_cuda(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
self._simple_alphacomposite(alpha_composite, device)
self._simple_wsum(weighted_sum, device)
self._simple_wsumnorm(norm_weighted_sum, device)
@@ -142,7 +143,7 @@ class TestAccumulatePoints(unittest.TestCase):
C = 3
P = 32
for d in ["cpu", "cuda"]:
for d in ["cpu", get_random_cuda_device()]:
# TODO(gkioxari) add torch.float64 to types after double precision
# support is added to atomicAdd
for t in [torch.float32]:
@@ -181,7 +182,7 @@ class TestAccumulatePoints(unittest.TestCase):
res1 = fn1(*args1)
res2 = fn2(*args2)
self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6))
self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6)
if not compare_grads:
return
@@ -200,7 +201,7 @@ class TestAccumulatePoints(unittest.TestCase):
grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]
for i in range(0, len(grads1)):
self.assertTrue(torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6))
self.assertClose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)
def _simple_wsum(self, accum_func, device):
# Initialise variables
@@ -273,7 +274,7 @@ class TestAccumulatePoints(unittest.TestCase):
]
).to(device)
self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)
def _simple_wsumnorm(self, accum_func, device):
# Initialise variables
@@ -346,7 +347,7 @@ class TestAccumulatePoints(unittest.TestCase):
]
).to(device)
self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)
def _simple_alphacomposite(self, accum_func, device):
# Initialise variables

View File

@@ -33,7 +33,9 @@ class TestCubify(unittest.TestCase):
# 1st-check
verts, faces = meshes.get_mesh_verts_faces(0)
self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
self.assertTrue(
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
)
self.assertTrue(
torch.allclose(
verts,
@@ -78,7 +80,9 @@ class TestCubify(unittest.TestCase):
)
# 2nd-check
verts, faces = meshes.get_mesh_verts_faces(1)
self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
self.assertTrue(
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
)
self.assertTrue(
torch.allclose(
verts,

View File

@@ -4,7 +4,7 @@
import unittest
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import mesh_face_areas_normals
from pytorch3d.structures.meshes import Meshes
@@ -94,13 +94,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
self._test_face_areas_normals_helper("cpu")
def test_face_areas_normals_cuda(self):
self._test_face_areas_normals_helper("cuda:0")
device = get_random_cuda_device()
self._test_face_areas_normals_helper(device)
def test_nonfloats_cpu(self):
self._test_face_areas_normals_helper("cpu", dtype=torch.double)
def test_nonfloats_cuda(self):
self._test_face_areas_normals_helper("cuda:0", dtype=torch.double)
device = get_random_cuda_device()
self._test_face_areas_normals_helper(device, dtype=torch.double)
@staticmethod
def face_areas_normals_with_init(

View File

@@ -4,7 +4,7 @@ import unittest
import torch
import torch.nn as nn
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
from pytorch3d.structures.meshes import Meshes
@@ -14,7 +14,7 @@ from pytorch3d.utils import ico_sphere
class TestGraphConv(TestCaseMixin, unittest.TestCase):
def test_undirected(self):
dtype = torch.float32
device = torch.device("cuda:0")
device = get_random_cuda_device()
verts = torch.tensor(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
)
@@ -97,7 +97,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
self.assertClose(y, expected_y)
def test_backward(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
mesh = ico_sphere()
verts = mesh.verts_packed()
edges = mesh.edges_packed()
@@ -118,7 +118,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)")
def test_cpu_cuda_tensor_error(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
verts = torch.tensor(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
)
@@ -134,7 +134,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
Check that gather_scatter cuda version throws an error if cpu tensors
are given as input.
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
mesh = ico_sphere()
verts = mesh.verts_packed()
edges = mesh.edges_packed()

View File

@@ -4,7 +4,7 @@ import unittest
from itertools import product
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
@@ -89,7 +89,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
self._knn_vs_python_square_helper(device)
def test_knn_vs_python_square_cuda(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
self._knn_vs_python_square_helper(device)
def _knn_vs_python_ragged_helper(self, device):
@@ -133,11 +133,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
self._knn_vs_python_ragged_helper(device)
def test_knn_vs_python_ragged_cuda(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
self._knn_vs_python_ragged_helper(device)
def test_knn_gather(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, P1, P2, K, D = 4, 16, 12, 8, 3
x = torch.rand((N, P1, D), device=device)
y = torch.rand((N, P2, D), device=device)

View File

@@ -3,7 +3,7 @@
import unittest
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import packed_to_padded, padded_to_packed
from pytorch3d.structures.meshes import Meshes
@@ -126,13 +126,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
self._test_packed_to_padded_helper(16, "cpu")
def test_packed_to_padded_flat_cuda(self):
self._test_packed_to_padded_helper(0, "cuda:0")
device = get_random_cuda_device()
self._test_packed_to_padded_helper(0, device)
def test_packed_to_padded_D1_cuda(self):
self._test_packed_to_padded_helper(1, "cuda:0")
device = get_random_cuda_device()
self._test_packed_to_padded_helper(1, device)
def test_packed_to_padded_D16_cuda(self):
self._test_packed_to_padded_helper(16, "cuda:0")
device = get_random_cuda_device()
self._test_packed_to_padded_helper(16, device)
def _test_padded_to_packed_helper(self, D, device):
"""
@@ -191,13 +194,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
self._test_padded_to_packed_helper(16, "cpu")
def test_padded_to_packed_flat_cuda(self):
self._test_padded_to_packed_helper(0, "cuda:0")
device = get_random_cuda_device()
self._test_padded_to_packed_helper(0, device)
def test_padded_to_packed_D1_cuda(self):
self._test_padded_to_packed_helper(1, "cuda:0")
device = get_random_cuda_device()
self._test_padded_to_packed_helper(1, device)
def test_padded_to_packed_D16_cuda(self):
self._test_padded_to_packed_helper(16, "cuda:0")
device = get_random_cuda_device()
self._test_padded_to_packed_helper(16, device)
def test_invalid_inputs_shapes(self, device="cuda:0"):
with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."):

View File

@@ -4,7 +4,7 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance
from pytorch3d.structures import Meshes, Pointclouds, packed_to_list
@@ -203,7 +203,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
& PointEdgeArrayDistanceBackward
"""
P, E = 16, 32
device = torch.device("cuda:0")
device = get_random_cuda_device()
points = torch.rand((P, 3), dtype=torch.float32, device=device)
edges = torch.rand((E, 2, 3), dtype=torch.float32, device=device)
@@ -246,9 +246,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for PointEdgeDistanceForward
& PointEdgeDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -327,9 +327,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for EdgePointDistanceForward
& EdgePointDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -409,9 +409,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
"""
Test point_mesh_edge_distance from pytorch3d.loss
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# clone and detach for another backward pass through the op
verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -480,7 +480,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
& PointFaceArrayDistanceBackward
"""
P, T = 16, 32
device = torch.device("cuda:0")
device = get_random_cuda_device()
points = torch.rand((P, 3), dtype=torch.float32, device=device)
tris = torch.rand((T, 3, 3), dtype=torch.float32, device=device)
@@ -525,9 +525,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for PointFaceDistanceForward
& PointFaceDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -608,9 +608,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
Test CUDA implementation for FacePointDistanceForward
& FacePointDistanceBackward
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# make points packed a leaf node
points_packed = pcls.points_packed().detach().clone() # (P, 3)
@@ -690,9 +690,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
"""
Test point_mesh_face_distance from pytorch3d.loss
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
N, V, F, P = 4, 32, 16, 24
meshes, pcls = self.init_meshes_clouds(N, V, F, P)
meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
# clone and detach for another backward pass through the op
verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -751,7 +751,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
@staticmethod
def point_mesh_edge(N: int, V: int, F: int, P: int, device: str):
device = torch.device(device)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
N, V, F, P, device=device
)
torch.cuda.synchronize()
def loss():
@@ -763,7 +765,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
@staticmethod
def point_mesh_face(N: int, V: int, F: int, P: int, device: str):
device = torch.device(device)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
N, V, F, P, device=device
)
torch.cuda.synchronize()
def loss():

View File

@@ -4,7 +4,7 @@ import functools
import unittest
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes,
@@ -32,7 +32,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
def test_simple_cuda_naive(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
self._test_behind_camera(rasterize_meshes, device, bin_size=0)
@@ -40,7 +40,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
def test_simple_cuda_binned(self):
device = torch.device("cuda:0")
device = get_random_cuda_device()
self._simple_triangle_raster(rasterize_meshes, device, bin_size=5)
self._simple_blurry_raster(rasterize_meshes, device, bin_size=5)
self._test_behind_camera(rasterize_meshes, device, bin_size=5)
@@ -54,7 +54,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
blur_radius = 0.1 ** 2
faces_per_pixel = 3
for d in ["cpu", "cuda"]:
for d in ["cpu", get_random_cuda_device()]:
device = torch.device(d)
compare_grads = True
# Mesh with a single face.
@@ -164,7 +164,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
verts1.requires_grad = True
meshes_cpu = Meshes(verts=[verts1], faces=[faces1])
device = torch.device("cuda:0")
device = get_random_cuda_device()
meshes_cuda = ico_sphere(0, device)
verts2, faces2 = meshes_cuda.get_mesh_verts_faces(0)
verts2.requires_grad = True
@@ -186,7 +186,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
return self._test_coarse_rasterize(torch.device("cpu"))
def test_coarse_cuda(self):
return self._test_coarse_rasterize(torch.device("cuda:0"))
return self._test_coarse_rasterize(get_random_cuda_device())
def test_cpp_vs_cuda_naive_vs_cuda_binned(self):
# Make sure that the backward pass runs for all pathways
@@ -221,7 +221,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
grad1 = verts.grad.data.cpu().clone()
# Option II: CUDA, naive
device = torch.device("cuda:0")
device = get_random_cuda_device()
meshes = ico_sphere(0, device)
verts, faces = meshes.get_mesh_verts_faces(0)
verts.requires_grad = True
@@ -229,9 +229,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
args = (meshes, image_size, radius, faces_per_pixel, 0, 0)
idx2, zbuf2, bary2, dist2 = rasterize_meshes(*args)
grad_zbuf = grad_zbuf.cuda()
grad_dist = grad_dist.cuda()
grad_bary = grad_bary.cuda()
grad_zbuf = grad_zbuf.to(device)
grad_dist = grad_dist.to(device)
grad_bary = grad_bary.to(device)
loss = (
(zbuf2 * grad_zbuf).sum()
+ (dist2 * grad_dist).sum()
@@ -244,7 +244,6 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
grad2 = verts.grad.data.cpu().clone()
# Option III: CUDA, binned
device = torch.device("cuda:0")
meshes = ico_sphere(0, device)
verts, faces = meshes.get_mesh_verts_faces(0)
verts.requires_grad = True
@@ -302,7 +301,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
bin_size,
max_faces_per_bin,
)
device = torch.device("cuda:0")
device = get_random_cuda_device()
meshes = meshes.clone().to(device)
faces = meshes.faces_packed()
@@ -356,8 +355,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
verts1, faces1 = meshes.get_mesh_verts_faces(0)
verts1.requires_grad = True
meshes1 = Meshes(verts=[verts1], faces=[faces1])
verts2 = verts1.detach().cuda().requires_grad_(True)
faces2 = faces1.detach().clone().cuda()
device = get_random_cuda_device()
verts2 = verts1.detach().to(device).requires_grad_(True)
faces2 = faces1.detach().clone().to(device)
meshes2 = Meshes(verts=[verts2], faces=[faces2])
kwargs = {"image_size": 64, "perspective_correct": True}
@@ -367,7 +367,8 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
def test_cuda_naive_vs_binned_perspective_correct(self):
meshes = ico_sphere(2, device=torch.device("cuda"))
device = get_random_cuda_device()
meshes = ico_sphere(2, device=device)
verts1, faces1 = meshes.get_mesh_verts_faces(0)
verts1.requires_grad = True
meshes1 = Meshes(verts=[verts1], faces=[faces1])
@@ -1029,7 +1030,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
max_faces_per_bin: int,
):
meshes = ico_sphere(ico_level, torch.device("cuda:0"))
meshes = ico_sphere(ico_level, get_random_cuda_device())
meshes_batch = meshes.extend(num_meshes)
torch.cuda.synchronize()

View File

@@ -5,7 +5,7 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
rasterize_points,
@@ -25,7 +25,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
self._simple_test_case(rasterize_points, device)
def test_naive_simple_cuda(self):
device = torch.device("cuda")
device = get_random_cuda_device()
self._simple_test_case(rasterize_points, device, bin_size=0)
def test_python_behind_camera(self):
@@ -37,7 +37,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
self._test_behind_camera(rasterize_points, torch.device("cpu"))
def test_cuda_behind_camera(self):
self._test_behind_camera(rasterize_points, torch.device("cuda"), bin_size=0)
device = get_random_cuda_device()
self._test_behind_camera(rasterize_points, device, bin_size=0)
def test_cpp_vs_naive_vs_binned(self):
# Make sure that the backward pass runs for all pathways
@@ -373,7 +374,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
return self._test_coarse_rasterize(torch.device("cpu"))
def test_coarse_cuda(self):
return self._test_coarse_rasterize(torch.device("cuda"))
device = get_random_cuda_device()
return self._test_coarse_rasterize(device)
def test_compare_coarse_cpu_vs_cuda(self):
torch.manual_seed(231)
@@ -405,7 +407,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
)
bp_cpu = _C._rasterize_points_coarse(*args)
pointclouds_cuda = pointclouds.to("cuda:0")
device = get_random_cuda_device()
pointclouds_cuda = pointclouds.to(device)
points_packed = pointclouds_cuda.points_packed()
cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()

View File

@@ -5,7 +5,7 @@ import unittest
from pathlib import Path
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
@@ -42,7 +42,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
Check sample_points_from_meshes raises an exception if all meshes are
invalid.
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
verts1 = torch.tensor([], dtype=torch.float32, device=device)
faces1 = torch.tensor([], dtype=torch.int64, device=device)
meshes = Meshes(verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1])
@@ -56,7 +56,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
For an ico_sphere, the sampled vertices should lie on a unit sphere.
For an empty mesh, the samples and normals should be 0.
"""
device = torch.device("cuda:0")
device = get_random_cuda_device()
# Unit simplex.
verts_pyramid = torch.tensor(