Adding utility methods to TensorProperties

Summary:
Context: in the code we are releasing with CO3D dataset, we use  `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build.

Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed.

Reviewed By: bottler

Differential Revision: D29659529

fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce
This commit is contained in:
Roman Shapovalov
2021-07-13 10:28:41 -07:00
committed by Facebook GitHub Bot
parent fa44a05567
commit 0c02ae907e
6 changed files with 31 additions and 9 deletions

View File

@@ -709,9 +709,9 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertEqual(cpu_device, mesh.device)
self.assertIs(mesh, converted_mesh)
cuda_device = torch.device("cuda")
cuda_device = torch.device("cuda:0")
converted_mesh = mesh.to("cuda")
converted_mesh = mesh.to("cuda:0")
self.assertEqual(cuda_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIsNot(mesh, converted_mesh)

View File

@@ -39,7 +39,17 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0))
device = torch.device("cuda:0")
new_example = example.to(device=device)
self.assertTrue(new_example.device == device)
self.assertEqual(new_example.device, device)
example_cpu = example.cpu()
self.assertEqual(example_cpu.device, torch.device("cpu"))
example_gpu = example.cuda()
self.assertEqual(example_gpu.device.type, "cuda")
self.assertIsNotNone(example_gpu.device.index)
example_gpu1 = example.cuda(1)
self.assertEqual(example_gpu1.device, torch.device("cuda:1"))
def test_clone(self):
# Check clone method

View File

@@ -22,7 +22,7 @@ from pytorch3d.structures.meshes import Meshes
class TestShader(TestCaseMixin, unittest.TestCase):
def test_to(self):
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda")
cuda_device = torch.device("cuda:0")
R, T = look_at_view_transform()

View File

@@ -50,9 +50,9 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cpu_t)
cuda_device = torch.device("cuda")
cuda_device = torch.device("cuda:0")
cuda_t = t.to("cuda")
cuda_t = t.to("cuda:0")
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cuda_t.dtype)