Tidy uses of torch.device in Meshes

Summary:
Tidy uses of `torch.device` in `Meshes`:
- Allow `str` or `torch.device` in `Meshes.to()` method
- Consistently use `torch.device` for internal type
- Fix comparison of devices

Reviewed By: nikhilaravi

Differential Revision: D28969461

fbshipit-source-id: 16d3c1f5458954bb11fdf0efea88542e94dccd7a
This commit is contained in:
Patrick Labatut
2021-06-09 15:48:56 -07:00
committed by Facebook GitHub Bot
parent 13a0110b69
commit 633d66f1f0
2 changed files with 45 additions and 23 deletions

View File

@@ -719,12 +719,31 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
mesh.extend(N=-1)
def test_to(self):
mesh = init_mesh(5, 10, 100, device=torch.device("cuda:0"))
device = torch.device("cuda:1")
mesh = init_mesh(5, 10, 100)
new_mesh = mesh.to(device)
self.assertTrue(new_mesh.device == device)
self.assertTrue(mesh.device == torch.device("cuda:0"))
cpu_device = torch.device("cpu")
converted_mesh = mesh.to("cpu")
self.assertEqual(cpu_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIs(mesh, converted_mesh)
converted_mesh = mesh.to(cpu_device)
self.assertEqual(cpu_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIs(mesh, converted_mesh)
cuda_device = torch.device("cuda")
converted_mesh = mesh.to("cuda")
self.assertEqual(cuda_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIsNot(mesh, converted_mesh)
converted_mesh = mesh.to(cuda_device)
self.assertEqual(cuda_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIsNot(mesh, converted_mesh)
def test_split_mesh(self):
mesh = init_mesh(5, 10, 100)