mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Cuda updates
Summary: Updates to: - enable cuda kernel launches on any GPU (not just the default) - cuda and contiguous checks for all kernels - checks to ensure all tensors are on the same device - error reporting in the cuda kernels - cuda tests now run on a random device not just the default Reviewed By: jcjohnson, gkioxari Differential Revision: D21215280 fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c9267ab7af
commit
c3d636dc8c
@@ -4,7 +4,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
@@ -14,7 +14,7 @@ from pytorch3d.utils import ico_sphere
|
||||
class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
def test_undirected(self):
|
||||
dtype = torch.float32
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
verts = torch.tensor(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
|
||||
)
|
||||
@@ -97,7 +97,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(y, expected_y)
|
||||
|
||||
def test_backward(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
mesh = ico_sphere()
|
||||
verts = mesh.verts_packed()
|
||||
edges = mesh.edges_packed()
|
||||
@@ -118,7 +118,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)")
|
||||
|
||||
def test_cpu_cuda_tensor_error(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
verts = torch.tensor(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
|
||||
)
|
||||
@@ -134,7 +134,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
Check that gather_scatter cuda version throws an error if cpu tensors
|
||||
are given as input.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
mesh = ico_sphere()
|
||||
verts = mesh.verts_packed()
|
||||
edges = mesh.edges_packed()
|
||||
|
||||
Reference in New Issue
Block a user