mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Tidy uses of torch.device in Pointclouds
Summary: Tidy uses of `torch.device` in `Pointclouds`: - Allow `str` or `torch.device` in `Pointclouds.to()` method - Consistently use `torch.device` for internal type - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28970221 fbshipit-source-id: 3ca7104d4c0d9b20b0cff4f00e3ce931c5f1528a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
633d66f1f0
commit
1db40ac566
@@ -633,6 +633,33 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
clouds.extend(N=-1)
|
||||
|
||||
def test_to(self):
|
||||
cloud = self.init_cloud(5, 100, 10) # Using device "cuda:0"
|
||||
|
||||
cuda_device = torch.device("cuda:0")
|
||||
|
||||
converted_cloud = cloud.to("cuda:0")
|
||||
self.assertEqual(cuda_device, converted_cloud.device)
|
||||
self.assertEqual(cuda_device, cloud.device)
|
||||
self.assertIs(cloud, converted_cloud)
|
||||
|
||||
converted_cloud = cloud.to(cuda_device)
|
||||
self.assertEqual(cuda_device, converted_cloud.device)
|
||||
self.assertEqual(cuda_device, cloud.device)
|
||||
self.assertIs(cloud, converted_cloud)
|
||||
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
converted_cloud = cloud.to("cpu")
|
||||
self.assertEqual(cpu_device, converted_cloud.device)
|
||||
self.assertEqual(cuda_device, cloud.device)
|
||||
self.assertIsNot(cloud, converted_cloud)
|
||||
|
||||
converted_cloud = cloud.to(cpu_device)
|
||||
self.assertEqual(cpu_device, converted_cloud.device)
|
||||
self.assertEqual(cuda_device, cloud.device)
|
||||
self.assertIsNot(cloud, converted_cloud)
|
||||
|
||||
def test_to_list(self):
|
||||
cloud = self.init_cloud(5, 100, 10)
|
||||
device = torch.device("cuda:1")
|
||||
|
||||
Reference in New Issue
Block a user