Mirror of https://github.com/facebookresearch/pytorch3d.git
CUDA updates
Summary: Updates to:
- enable cuda kernel launches on any GPU (not just the default)
- cuda and contiguous checks for all kernels
- checks to ensure all tensors are on the same device
- error reporting in the cuda kernels
- cuda tests now run on a random device not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
Committed by: Facebook GitHub Bot
Parent: c9267ab7af
Commit: c3d636dc8c
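The device and input checks described in the summary live in the C++/CUDA extension code, which is not part of this test-only diff. As a rough illustration only (the helper name check_inputs below is hypothetical and not part of PyTorch3D), the behaviour the kernels now enforce corresponds to Python-level logic along these lines:

import torch


def check_inputs(*tensors: torch.Tensor) -> None:
    # Illustrative sketch only: the real checks live in the CUDA/C++ kernels.
    # Each kernel now verifies that its inputs are CUDA tensors, contiguous,
    # and all on the same device, and launches on that device rather than on
    # the default GPU.
    device = tensors[0].device
    for t in tensors:
        if not t.is_cuda:
            raise ValueError("expected a CUDA tensor")
        if not t.is_contiguous():
            raise ValueError("expected a contiguous tensor")
        if t.device != device:
            raise ValueError("all tensors must be on the same device")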
@@ -8,11 +8,21 @@ from test_chamfer import TestChamfer


def bm_chamfer() -> None:
-    kwargs_list_naive = [
-        {"batch_size": 1, "P1": 32, "P2": 64, "return_normals": False},
-        {"batch_size": 1, "P1": 32, "P2": 64, "return_normals": True},
-        {"batch_size": 32, "P1": 32, "P2": 64, "return_normals": False},
-    ]
+    devices = ["cpu"]
+    if torch.cuda.is_available():
+        devices.append("cuda:0")
+
+    kwargs_list_naive = []
+    batch_size = [1, 32]
+    return_normals = [True, False]
+    test_cases = product(batch_size, return_normals, devices)
+
+    for case in test_cases:
+        b, n, d = case
+        kwargs_list_naive.append(
+            {"batch_size": b, "P1": 32, "P2": 64, "return_normals": n, "device": d}
+        )
+
    benchmark(
        TestChamfer.chamfer_naive_with_init,
        "CHAMFER_NAIVE",
@@ -21,6 +31,7 @@ def bm_chamfer() -> None:
    )

    if torch.cuda.is_available():
+        device = "cuda:0"
        kwargs_list = []
        batch_size = [1, 32]
        P1 = [32, 1000, 10000]
@@ -38,6 +49,7 @@ def bm_chamfer() -> None:
                    "P2": p2,
                    "return_normals": n,
                    "homogeneous": h,
+                    "device": device,
                }
            )
        benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
@@ -20,6 +20,18 @@ def load_rgb_image(filename: str, data_dir: Union[str, Path]):
TensorOrArray = Union[torch.Tensor, np.ndarray]


+def get_random_cuda_device() -> str:
+    """
+    Function to get a random GPU device from the
+    available devices. This is useful for testing
+    that custom cuda kernels can support inputs on
+    any device without having to set the device explicitly.
+    """
+    num_devices = torch.cuda.device_count()
+    rand_device_id = torch.randint(high=num_devices, size=(1,)).item()
+    return "cuda:%d" % rand_device_id
+
+
class TestCaseMixin(unittest.TestCase):
    def assertSeparate(self, tensor1, tensor2) -> None:
        """
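The tests changed below pick up this helper by importing it from common_testing. A minimal usage sketch (the test class and tensor shapes here are illustrative, not taken from this commit):

import unittest

import torch
from common_testing import get_random_cuda_device


class TestRunsOnAnyGpu(unittest.TestCase):  # illustrative test, not in this diff
    def test_op_on_random_device(self):
        device = get_random_cuda_device()  # e.g. "cuda:1" on a multi-GPU machine
        x = torch.rand((4, 3), device=device)
        # A correctly written custom kernel accepts x regardless of which GPU
        # holds it, without any call to torch.cuda.set_device.
        self.assertEqual(x.device, torch.device(device))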
@@ -6,7 +6,7 @@ from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures.pointclouds import Pointclouds
@@ -81,7 +81,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        )

    @staticmethod
-    def chamfer_distance_naive_pointclouds(p1, p2):
+    def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
        """
        Naive iterative implementation of nearest neighbor and chamfer distance.
        x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -97,7 +97,6 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        x_normals = p1.normals_padded()
        y_normals = p2.normals_padded()

-        device = torch.device("cuda:0")
        return_normals = x_normals is not None and y_normals is not None

        # Initialize all distances to + inf
@@ -163,7 +162,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        """
        N, P1, D = x.shape
        P2 = y.size(1)
-        device = torch.device("cuda:0")
+        device = x.device
        return_normals = x_normals is not None and y_normals is not None
        dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
@@ -203,7 +202,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        This tests only uses homogeneous pointclouds.
        """
        N, max_P1, max_P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
        p1 = points_normals.p1
        p2 = points_normals.p2
@@ -237,7 +236,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        which supports heterogeneous pointcloud objects.
        """
        N, max_P1, max_P2 = 3, 70, 70
-        device = "cuda:0"
+        device = get_random_cuda_device()
        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
        weights = points_normals.weights
        x_lengths = points_normals.p1_lengths
@@ -256,7 +255,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

        # Chamfer with pointclouds as input.
        pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
-            points_normals.cloud1, points_normals.cloud2
+            points_normals.cloud1, points_normals.cloud2, device=device
        )

        # Mean reduction point loss.
@@ -299,7 +298,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
    def test_chamfer_pointcloud_object_withnormals(self):
        N = 5
        P1, P2 = 100, 100
-        device = "cuda:0"
+        device = get_random_cuda_device()

        reductions = [
            ("sum", "sum"),
@@ -359,7 +358,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
    def test_chamfer_pointcloud_object_nonormals(self):
        N = 5
        P1, P2 = 100, 100
-        device = "cuda:0"
+        device = get_random_cuda_device()

        reductions = [
            ("sum", "sum"),
@@ -415,7 +414,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        for point_reduction = "mean" and batch_reduction = None.
        """
        N, max_P1, max_P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
        p1 = points_normals.p1
        p2 = points_normals.p2
@@ -464,7 +463,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        for point_reduction = "sum" and batch_reduction = None.
        """
        N, P1, P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
        p1 = points_normals.p1
        p2 = points_normals.p2
@@ -579,7 +578,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
        point_reduction in ["mean", "sum"].
        """
        N, max_P1, max_P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()

        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
        p1 = points_normals.p1
@@ -681,7 +680,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

    def test_incorrect_weights(self):
        N, P1, P2 = 16, 64, 128
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        p1 = torch.rand(
            (N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
        )
@@ -716,7 +715,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

    def test_incorrect_inputs(self):
        N, P1, P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
        p1 = points_normals.p1
        p2 = points_normals.p2
@@ -740,11 +739,16 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

    @staticmethod
    def chamfer_with_init(
-        batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
+        batch_size: int,
+        P1: int,
+        P2: int,
+        return_normals: bool,
+        homogeneous: bool,
+        device="cpu",
    ):
-        p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
-            batch_size, P1, P2
-        )
+        points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
+        l1 = points_normals.p1_lengths
+        l2 = points_normals.p2_lengths
        if homogeneous:
            # Set lengths to None so in Chamfer it assumes
            # there is no padding.
@@ -754,13 +758,13 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

        def loss():
            loss, loss_normals = chamfer_distance(
-                p1,
-                p2,
+                points_normals.p1,
+                points_normals.p2,
                x_lengths=l1,
                y_lengths=l2,
-                x_normals=p1_normals,
-                y_normals=p2_normals,
-                weights=weights,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
+                weights=points_normals.weights,
            )
            torch.cuda.synchronize()
@@ -768,16 +772,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

    @staticmethod
    def chamfer_naive_with_init(
-        batch_size: int, P1: int, P2: int, return_normals: bool
+        batch_size: int, P1: int, P2: int, return_normals: bool, device="cpu"
    ):
-        p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
-            batch_size, P1, P2
-        )
+        points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
        torch.cuda.synchronize()

        def loss():
            loss, loss_normals = TestChamfer.chamfer_distance_naive(
-                p1, p2, x_normals=p1_normals, y_normals=p2_normals
+                points_normals.p1,
+                points_normals.p2,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
            )
            torch.cuda.synchronize()
@@ -3,6 +3,7 @@
import unittest

import torch
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.renderer.compositing import (
    alpha_composite,
    norm_weighted_sum,
@@ -10,7 +11,7 @@ from pytorch3d.renderer.compositing import (
)


-class TestAccumulatePoints(unittest.TestCase):
+class TestAccumulatePoints(TestCaseMixin, unittest.TestCase):

    # NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
    @staticmethod
@@ -120,7 +121,7 @@ class TestAccumulatePoints(unittest.TestCase):
        self._simple_wsumnorm(norm_weighted_sum, device)

    def test_cuda(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        self._simple_alphacomposite(alpha_composite, device)
        self._simple_wsum(weighted_sum, device)
        self._simple_wsumnorm(norm_weighted_sum, device)
@@ -142,7 +143,7 @@ class TestAccumulatePoints(unittest.TestCase):
        C = 3
        P = 32

-        for d in ["cpu", "cuda"]:
+        for d in ["cpu", get_random_cuda_device()]:
            # TODO(gkioxari) add torch.float64 to types after double precision
            # support is added to atomicAdd
            for t in [torch.float32]:
@@ -181,7 +182,7 @@ class TestAccumulatePoints(unittest.TestCase):
        res1 = fn1(*args1)
        res2 = fn2(*args2)

-        self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6))
+        self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6)

        if not compare_grads:
            return
@@ -200,7 +201,7 @@ class TestAccumulatePoints(unittest.TestCase):
        grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]

        for i in range(0, len(grads1)):
-            self.assertTrue(torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6))
+            self.assertClose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)

    def _simple_wsum(self, accum_func, device):
        # Initialise variables
@@ -273,7 +274,7 @@ class TestAccumulatePoints(unittest.TestCase):
            ]
        ).to(device)

-        self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
+        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

    def _simple_wsumnorm(self, accum_func, device):
        # Initialise variables
@@ -346,7 +347,7 @@ class TestAccumulatePoints(unittest.TestCase):
            ]
        ).to(device)

-        self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
+        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

    def _simple_alphacomposite(self, accum_func, device):
        # Initialise variables
@@ -33,7 +33,9 @@ class TestCubify(unittest.TestCase):

        # 1st-check
        verts, faces = meshes.get_mesh_verts_faces(0)
-        self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
+        self.assertTrue(
+            torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
+        )
        self.assertTrue(
            torch.allclose(
                verts,
@@ -78,7 +80,9 @@
        )
        # 2nd-check
        verts, faces = meshes.get_mesh_verts_faces(1)
-        self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
+        self.assertTrue(
+            torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
+        )
        self.assertTrue(
            torch.allclose(
                verts,
@@ -4,7 +4,7 @@
import unittest

import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import mesh_face_areas_normals
from pytorch3d.structures.meshes import Meshes
@@ -94,13 +94,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
        self._test_face_areas_normals_helper("cpu")

    def test_face_areas_normals_cuda(self):
-        self._test_face_areas_normals_helper("cuda:0")
+        device = get_random_cuda_device()
+        self._test_face_areas_normals_helper(device)

    def test_nonfloats_cpu(self):
        self._test_face_areas_normals_helper("cpu", dtype=torch.double)

    def test_nonfloats_cuda(self):
-        self._test_face_areas_normals_helper("cuda:0", dtype=torch.double)
+        device = get_random_cuda_device()
+        self._test_face_areas_normals_helper(device, dtype=torch.double)

    @staticmethod
    def face_areas_normals_with_init(
@@ -4,7 +4,7 @@ import unittest

import torch
import torch.nn as nn
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
from pytorch3d.structures.meshes import Meshes
@@ -14,7 +14,7 @@ from pytorch3d.utils import ico_sphere
class TestGraphConv(TestCaseMixin, unittest.TestCase):
    def test_undirected(self):
        dtype = torch.float32
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
        )
@@ -97,7 +97,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
        self.assertClose(y, expected_y)

    def test_backward(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
@@ -118,7 +118,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
        self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)")

    def test_cpu_cuda_tensor_error(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
        )
@@ -134,7 +134,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
        Check that gather_scatter cuda version throws an error if cpu tensors
        are given as input.
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
@@ -4,7 +4,7 @@ import unittest
from itertools import product

import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points


@@ -89,7 +89,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
        self._knn_vs_python_square_helper(device)

    def test_knn_vs_python_square_cuda(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        self._knn_vs_python_square_helper(device)

    def _knn_vs_python_ragged_helper(self, device):
@@ -133,11 +133,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
        self._knn_vs_python_ragged_helper(device)

    def test_knn_vs_python_ragged_cuda(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        self._knn_vs_python_ragged_helper(device)

    def test_knn_gather(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        N, P1, P2, K, D = 4, 16, 12, 8, 3
        x = torch.rand((N, P1, D), device=device)
        y = torch.rand((N, P2, D), device=device)
@@ -3,7 +3,7 @@
import unittest

import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import packed_to_padded, padded_to_packed
from pytorch3d.structures.meshes import Meshes
@@ -126,13 +126,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
        self._test_packed_to_padded_helper(16, "cpu")

    def test_packed_to_padded_flat_cuda(self):
-        self._test_packed_to_padded_helper(0, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper(0, device)

    def test_packed_to_padded_D1_cuda(self):
-        self._test_packed_to_padded_helper(1, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper(1, device)

    def test_packed_to_padded_D16_cuda(self):
-        self._test_packed_to_padded_helper(16, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper(16, device)

    def _test_padded_to_packed_helper(self, D, device):
        """
@@ -191,13 +194,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
        self._test_padded_to_packed_helper(16, "cpu")

    def test_padded_to_packed_flat_cuda(self):
-        self._test_padded_to_packed_helper(0, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper(0, device)

    def test_padded_to_packed_D1_cuda(self):
-        self._test_padded_to_packed_helper(1, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper(1, device)

    def test_padded_to_packed_D16_cuda(self):
-        self._test_padded_to_packed_helper(16, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper(16, device)

    def test_invalid_inputs_shapes(self, device="cuda:0"):
        with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."):
@@ -4,7 +4,7 @@ import unittest

import numpy as np
import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance
from pytorch3d.structures import Meshes, Pointclouds, packed_to_list
@@ -203,7 +203,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        & PointEdgeArrayDistanceBackward
        """
        P, E = 16, 32
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        points = torch.rand((P, 3), dtype=torch.float32, device=device)
        edges = torch.rand((E, 2, 3), dtype=torch.float32, device=device)
@@ -246,9 +246,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for PointEdgeDistanceForward
        & PointEdgeDistanceBackward
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -327,9 +327,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for EdgePointDistanceForward
        & EdgePointDistanceBackward
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -409,9 +409,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        """
        Test point_mesh_edge_distance from pytorch3d.loss
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # clone and detach for another backward pass through the op
        verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -480,7 +480,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        & PointFaceArrayDistanceBackward
        """
        P, T = 16, 32
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        points = torch.rand((P, 3), dtype=torch.float32, device=device)
        tris = torch.rand((T, 3, 3), dtype=torch.float32, device=device)
@@ -525,9 +525,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for PointFaceDistanceForward
        & PointFaceDistanceBackward
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -608,9 +608,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for FacePointDistanceForward
        & FacePointDistanceBackward
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -690,9 +690,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        """
        Test point_mesh_face_distance from pytorch3d.loss
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # clone and detach for another backward pass through the op
        verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -751,7 +751,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
    @staticmethod
    def point_mesh_edge(N: int, V: int, F: int, P: int, device: str):
        device = torch.device(device)
-        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
+            N, V, F, P, device=device
+        )
        torch.cuda.synchronize()

        def loss():
@@ -763,7 +765,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
    @staticmethod
    def point_mesh_face(N: int, V: int, F: int, P: int, device: str):
        device = torch.device(device)
-        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
+            N, V, F, P, device=device
+        )
        torch.cuda.synchronize()

        def loss():
@@ -4,7 +4,7 @@ import functools
import unittest

import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.mesh.rasterize_meshes import (
    rasterize_meshes,
@@ -32,7 +32,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        self._test_back_face_culling(rasterize_meshes, device, bin_size=0)

    def test_simple_cuda_naive(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
        self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
        self._test_behind_camera(rasterize_meshes, device, bin_size=0)
@@ -40,7 +40,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        self._test_back_face_culling(rasterize_meshes, device, bin_size=0)

    def test_simple_cuda_binned(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        self._simple_triangle_raster(rasterize_meshes, device, bin_size=5)
        self._simple_blurry_raster(rasterize_meshes, device, bin_size=5)
        self._test_behind_camera(rasterize_meshes, device, bin_size=5)
@@ -54,7 +54,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        blur_radius = 0.1 ** 2
        faces_per_pixel = 3

-        for d in ["cpu", "cuda"]:
+        for d in ["cpu", get_random_cuda_device()]:
            device = torch.device(d)
            compare_grads = True
            # Mesh with a single face.
@@ -164,7 +164,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        verts1.requires_grad = True
        meshes_cpu = Meshes(verts=[verts1], faces=[faces1])

-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        meshes_cuda = ico_sphere(0, device)
        verts2, faces2 = meshes_cuda.get_mesh_verts_faces(0)
        verts2.requires_grad = True
@@ -186,7 +186,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        return self._test_coarse_rasterize(torch.device("cpu"))

    def test_coarse_cuda(self):
-        return self._test_coarse_rasterize(torch.device("cuda:0"))
+        return self._test_coarse_rasterize(get_random_cuda_device())

    def test_cpp_vs_cuda_naive_vs_cuda_binned(self):
        # Make sure that the backward pass runs for all pathways
@@ -221,7 +221,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        grad1 = verts.grad.data.cpu().clone()

        # Option II: CUDA, naive
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        meshes = ico_sphere(0, device)
        verts, faces = meshes.get_mesh_verts_faces(0)
        verts.requires_grad = True
@@ -229,9 +229,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):

        args = (meshes, image_size, radius, faces_per_pixel, 0, 0)
        idx2, zbuf2, bary2, dist2 = rasterize_meshes(*args)
-        grad_zbuf = grad_zbuf.cuda()
-        grad_dist = grad_dist.cuda()
-        grad_bary = grad_bary.cuda()
+        grad_zbuf = grad_zbuf.to(device)
+        grad_dist = grad_dist.to(device)
+        grad_bary = grad_bary.to(device)
        loss = (
            (zbuf2 * grad_zbuf).sum()
            + (dist2 * grad_dist).sum()
@@ -244,7 +244,6 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        grad2 = verts.grad.data.cpu().clone()

        # Option III: CUDA, binned
-        device = torch.device("cuda:0")
        meshes = ico_sphere(0, device)
        verts, faces = meshes.get_mesh_verts_faces(0)
        verts.requires_grad = True
@@ -302,7 +301,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
            bin_size,
            max_faces_per_bin,
        )
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        meshes = meshes.clone().to(device)

        faces = meshes.faces_packed()
@@ -356,8 +355,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        verts1, faces1 = meshes.get_mesh_verts_faces(0)
        verts1.requires_grad = True
        meshes1 = Meshes(verts=[verts1], faces=[faces1])
-        verts2 = verts1.detach().cuda().requires_grad_(True)
-        faces2 = faces1.detach().clone().cuda()
+        device = get_random_cuda_device()
+        verts2 = verts1.detach().to(device).requires_grad_(True)
+        faces2 = faces1.detach().clone().to(device)
        meshes2 = Meshes(verts=[verts2], faces=[faces2])

        kwargs = {"image_size": 64, "perspective_correct": True}
@@ -367,7 +367,8 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)

    def test_cuda_naive_vs_binned_perspective_correct(self):
-        meshes = ico_sphere(2, device=torch.device("cuda"))
+        device = get_random_cuda_device()
+        meshes = ico_sphere(2, device=device)
        verts1, faces1 = meshes.get_mesh_verts_faces(0)
        verts1.requires_grad = True
        meshes1 = Meshes(verts=[verts1], faces=[faces1])
@@ -1029,7 +1030,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        max_faces_per_bin: int,
    ):

-        meshes = ico_sphere(ico_level, torch.device("cuda:0"))
+        meshes = ico_sphere(ico_level, get_random_cuda_device())
        meshes_batch = meshes.extend(num_meshes)
        torch.cuda.synchronize()
@@ -5,7 +5,7 @@ import unittest

import numpy as np
import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
    rasterize_points,
@@ -25,7 +25,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        self._simple_test_case(rasterize_points, device)

    def test_naive_simple_cuda(self):
-        device = torch.device("cuda")
+        device = get_random_cuda_device()
        self._simple_test_case(rasterize_points, device, bin_size=0)

    def test_python_behind_camera(self):
@@ -37,7 +37,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        self._test_behind_camera(rasterize_points, torch.device("cpu"))

    def test_cuda_behind_camera(self):
-        self._test_behind_camera(rasterize_points, torch.device("cuda"), bin_size=0)
+        device = get_random_cuda_device()
+        self._test_behind_camera(rasterize_points, device, bin_size=0)

    def test_cpp_vs_naive_vs_binned(self):
        # Make sure that the backward pass runs for all pathways
@@ -373,7 +374,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        return self._test_coarse_rasterize(torch.device("cpu"))

    def test_coarse_cuda(self):
-        return self._test_coarse_rasterize(torch.device("cuda"))
+        device = get_random_cuda_device()
+        return self._test_coarse_rasterize(device)

    def test_compare_coarse_cpu_vs_cuda(self):
        torch.manual_seed(231)
@@ -405,7 +407,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        )
        bp_cpu = _C._rasterize_points_coarse(*args)

-        pointclouds_cuda = pointclouds.to("cuda:0")
+        device = get_random_cuda_device()
+        pointclouds_cuda = pointclouds.to(device)
        points_packed = pointclouds_cuda.points_packed()
        cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
        num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
@@ -5,7 +5,7 @@ import unittest
from pathlib import Path

import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
@@ -42,7 +42,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
        Check sample_points_from_meshes raises an exception if all meshes are
        invalid.
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
        verts1 = torch.tensor([], dtype=torch.float32, device=device)
        faces1 = torch.tensor([], dtype=torch.int64, device=device)
        meshes = Meshes(verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1])
@@ -56,7 +56,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
        For an ico_sphere, the sampled vertices should lie on a unit sphere.
        For an empty mesh, the samples and normals should be 0.
        """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()

        # Unit simplex.
        verts_pyramid = torch.tensor(