Tidy uses of torch.device in Pointclouds

Summary:
Tidy uses of `torch.device` in `Pointclouds`:
- Allow `str` or `torch.device` in `Pointclouds.to()` method
- Consistently use `torch.device` for internal type
- Fix comparison of devices

Reviewed By: nikhilaravi

Differential Revision: D28970221

fbshipit-source-id: 3ca7104d4c0d9b20b0cff4f00e3ce931c5f1528a
This commit is contained in:
Patrick Labatut
2021-06-09 15:48:56 -07:00
committed by Facebook GitHub Bot
parent 633d66f1f0
commit 1db40ac566
2 changed files with 51 additions and 19 deletions

View File

@@ -633,6 +633,33 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError):
clouds.extend(N=-1)
def test_to(self):
cloud = self.init_cloud(5, 100, 10) # Using device "cuda:0"
cuda_device = torch.device("cuda:0")
converted_cloud = cloud.to("cuda:0")
self.assertEqual(cuda_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIs(cloud, converted_cloud)
converted_cloud = cloud.to(cuda_device)
self.assertEqual(cuda_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIs(cloud, converted_cloud)
cpu_device = torch.device("cpu")
converted_cloud = cloud.to("cpu")
self.assertEqual(cpu_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIsNot(cloud, converted_cloud)
converted_cloud = cloud.to(cpu_device)
self.assertEqual(cpu_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIsNot(cloud, converted_cloud)
def test_to_list(self):
cloud = self.init_cloud(5, 100, 10)
device = torch.device("cuda:1")