Cuda function for points2vols

Summary: Added CUDA implementation to match the new, still unused, C++ function for the core of points2vols.

Reviewed By: nikhilaravi

Differential Revision: D29548608

fbshipit-source-id: 16ebb61787fcb4c70461f9215a86ad5f97aecb4e
This commit is contained in:
Jeremy Reizenstein
2021-10-01 11:57:07 -07:00
committed by Facebook GitHub Bot
parent 0dfc6e0eb8
commit 9ad98c87c3
3 changed files with 433 additions and 2 deletions

View File

@@ -420,6 +420,18 @@ class TestRawFunction(TestCaseMixin, unittest.TestCase):
def test_grad_round_cpu(self):
self.do_gradcheck(torch.device("cpu"), False, False)
def test_grad_corners_splat_cuda(self):
self.do_gradcheck(torch.device("cuda:0"), True, True)
def test_grad_corners_round_cuda(self):
self.do_gradcheck(torch.device("cuda:0"), False, True)
def test_grad_splat_cuda(self):
self.do_gradcheck(torch.device("cuda:0"), True, False)
def test_grad_round_cuda(self):
self.do_gradcheck(torch.device("cuda:0"), False, False)
def do_gradcheck(self, device, splat: bool, align_corners: bool):
"""
Use gradcheck to verify the gradient of _points_to_volumes
@@ -492,6 +504,18 @@ class TestRawFunction(TestCaseMixin, unittest.TestCase):
def test_single_splat_cpu(self):
self.single_point(torch.device("cpu"), True, False)
def test_single_corners_round_cuda(self):
self.single_point(torch.device("cuda:0"), False, True)
def test_single_corners_splat_cuda(self):
self.single_point(torch.device("cuda:0"), True, True)
def test_single_round_cuda(self):
self.single_point(torch.device("cuda:0"), False, False)
def test_single_splat_cuda(self):
self.single_point(torch.device("cuda:0"), True, False)
def single_point(self, device, splat: bool, align_corners: bool):
"""
Check the outcome of _points_to_volumes where a single point