Tidy uses of torch.device in Volumes

Summary:
Tidy uses of `torch.device` in `Volumes`:
- Allow `str` or `torch.device` in `Volumes.to()` method
- Consistently use `torch.device` for internal type
- Fix comparison of devices

Reviewed By: nikhilaravi

Differential Revision: D28970876

fbshipit-source-id: c640cc22ced684a54cc450ac38a0f4b3435d47be
This commit is contained in:
Patrick Labatut
2021-06-09 15:48:56 -07:00
committed by Facebook GitHub Bot
parent 1db40ac566
commit 1f9661e150
2 changed files with 69 additions and 37 deletions

View File

@@ -445,38 +445,65 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
Test the moving of the volumes from/to gpu and cpu
"""
device = torch.device("cuda:0")
device_cpu = torch.device("cpu")
features = torch.randn(
size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
size=[num_volumes, num_channels, *size], dtype=torch.float32
)
densities = torch.rand(size=[num_volumes, 1, *size], device=device, dtype=dtype)
densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
volumes = Volumes(densities=densities, features=features)
# Test support for str and torch.device
cpu_device = torch.device("cpu")
converted_volumes = volumes.to("cpu")
self.assertEqual(cpu_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIs(volumes, converted_volumes)
converted_volumes = volumes.to(cpu_device)
self.assertEqual(cpu_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIs(volumes, converted_volumes)
cuda_device = torch.device("cuda:0")
converted_volumes = volumes.to("cuda:0")
self.assertEqual(cuda_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIsNot(volumes, converted_volumes)
converted_volumes = volumes.to(cuda_device)
self.assertEqual(cuda_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIsNot(volumes, converted_volumes)
# Test device placement of internal tensors
features = features.to(cuda_device)
densities = features.to(cuda_device)
for features_ in (features, None):
v = Volumes(densities=densities, features=features_)
volumes = Volumes(densities=densities, features=features_)
v_cpu = v.cpu()
v_cuda = v_cpu.cuda()
v_cuda_2 = v_cuda.cuda()
v_cpu_2 = v_cuda_2.cpu()
cpu_volumes = volumes.cpu()
cuda_volumes = cpu_volumes.cuda()
cuda_volumes2 = cuda_volumes.cuda()
cpu_volumes2 = cuda_volumes2.cpu()
for v1, v2 in itertools.combinations(
(v, v_cpu, v_cpu_2, v_cuda, v_cuda_2), 2
for volumes1, volumes2 in itertools.combinations(
(volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
):
if v1 is v_cuda and v2 is v_cuda_2:
if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
# checks that we do not copy if the devices stay the same
assert_fun = self.assertIs
else:
assert_fun = self.assertSeparate
assert_fun(v1._densities, v2._densities)
assert_fun(volumes1._densities, volumes2._densities)
if features_ is not None:
assert_fun(v1._features, v2._features)
for v_ in (v1, v2):
if v_ in (v_cpu, v_cpu_2):
self._check_vars_on_device(v_, device_cpu)
assert_fun(volumes1._features, volumes2._features)
for volumes_ in (volumes1, volumes2):
if volumes_ in (cpu_volumes, cpu_volumes2):
self._check_vars_on_device(volumes_, cpu_device)
else:
self._check_vars_on_device(v_, device)
self._check_vars_on_device(volumes_, cuda_device)
def _check_padded(self, x_pad, x_list, grid_sizes):
"""