From 1f9661e150592104bcb42d7458dadcf2ed894bcb Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 9 Jun 2021 15:48:56 -0700 Subject: [PATCH] Tidy uses of torch.device in Volumes Summary: Tidy uses of `torch.device` in `Volumes`: - Allow `str` or `torch.device` in `Volumes.to()` method - Consistently use `torch.device` for internal type - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28970876 fbshipit-source-id: c640cc22ced684a54cc450ac38a0f4b3435d47be --- pytorch3d/structures/volumes.py | 41 ++++++++++++--------- tests/test_volumes.py | 65 +++++++++++++++++++++++---------- 2 files changed, 69 insertions(+), 37 deletions(-) diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index cec18d58..6f9011b0 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union import torch +from ..common.types import Device, make_device from ..transforms import Scale, Transform3d from . import utils as struct_utils @@ -694,7 +695,7 @@ class Volumes: """ return copy.deepcopy(self) - def to(self, device, copy: bool = False) -> "Volumes": + def to(self, device: Device, copy: bool = False) -> "Volumes": """ Match the functionality of torch.Tensor.to() If copy = True or the self Tensor is on a different device, the @@ -703,30 +704,34 @@ class Volumes: then self is returned. Args: - **device**: Device id for the new tensor. - **copy**: Boolean indicator whether or not to clone self. Default False. + device: Device (as str or torch.device) for the new tensor. + copy: Boolean indicator whether or not to clone self. Default False. Returns: - Volumes object. + Volumes object. """ - if not copy and self.device == device: + device_ = make_device(device) + if not copy and self.device == device_: return self + other = self.clone() - if self.device != device: - other.device = device - # pyre-fixme[16]: `List` has no attribute `to`. - other._densities = self._densities.to(device) - if self._features is not None: - # pyre-fixme[16]: `Optional` has no attribute `to`. - other._features = self.features().to(device) - other._local_to_world_transform = ( - self.get_local_to_world_coords_transform().to(device) - ) - other._grid_sizes = self._grid_sizes.to(device) + if self.device == device_: + return other + + other.device = device_ + # pyre-fixme[16]: `List` has no attribute `to`. + other._densities = self._densities.to(device_) + if self._features is not None: + # pyre-fixme[16]: `Optional` has no attribute `to`. + other._features = self.features().to(device_) + other._local_to_world_transform = self.get_local_to_world_coords_transform().to( + device_ + ) + other._grid_sizes = self._grid_sizes.to(device_) return other def cpu(self) -> "Volumes": - return self.to(torch.device("cpu")) + return self.to("cpu") def cuda(self) -> "Volumes": - return self.to(torch.device("cuda")) + return self.to("cuda") diff --git a/tests/test_volumes.py b/tests/test_volumes.py index fd7bc27e..9b4053d8 100644 --- a/tests/test_volumes.py +++ b/tests/test_volumes.py @@ -445,38 +445,65 @@ class TestVolumes(TestCaseMixin, unittest.TestCase): Test the moving of the volumes from/to gpu and cpu """ - device = torch.device("cuda:0") - device_cpu = torch.device("cpu") - features = torch.randn( - size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32 + size=[num_volumes, num_channels, *size], dtype=torch.float32 ) - densities = torch.rand(size=[num_volumes, 1, *size], device=device, dtype=dtype) + densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype) + volumes = Volumes(densities=densities, features=features) + + # Test support for str and torch.device + cpu_device = torch.device("cpu") + + converted_volumes = volumes.to("cpu") + self.assertEqual(cpu_device, converted_volumes.device) + self.assertEqual(cpu_device, volumes.device) + self.assertIs(volumes, converted_volumes) + + converted_volumes = volumes.to(cpu_device) + self.assertEqual(cpu_device, converted_volumes.device) + self.assertEqual(cpu_device, volumes.device) + self.assertIs(volumes, converted_volumes) + + cuda_device = torch.device("cuda:0") + + converted_volumes = volumes.to("cuda:0") + self.assertEqual(cuda_device, converted_volumes.device) + self.assertEqual(cpu_device, volumes.device) + self.assertIsNot(volumes, converted_volumes) + + converted_volumes = volumes.to(cuda_device) + self.assertEqual(cuda_device, converted_volumes.device) + self.assertEqual(cpu_device, volumes.device) + self.assertIsNot(volumes, converted_volumes) + + # Test device placement of internal tensors + features = features.to(cuda_device) + densities = features.to(cuda_device) for features_ in (features, None): - v = Volumes(densities=densities, features=features_) + volumes = Volumes(densities=densities, features=features_) - v_cpu = v.cpu() - v_cuda = v_cpu.cuda() - v_cuda_2 = v_cuda.cuda() - v_cpu_2 = v_cuda_2.cpu() + cpu_volumes = volumes.cpu() + cuda_volumes = cpu_volumes.cuda() + cuda_volumes2 = cuda_volumes.cuda() + cpu_volumes2 = cuda_volumes2.cpu() - for v1, v2 in itertools.combinations( - (v, v_cpu, v_cpu_2, v_cuda, v_cuda_2), 2 + for volumes1, volumes2 in itertools.combinations( + (volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2 ): - if v1 is v_cuda and v2 is v_cuda_2: + if volumes1 is cuda_volumes and volumes2 is cuda_volumes2: # checks that we do not copy if the devices stay the same assert_fun = self.assertIs else: assert_fun = self.assertSeparate - assert_fun(v1._densities, v2._densities) + assert_fun(volumes1._densities, volumes2._densities) if features_ is not None: - assert_fun(v1._features, v2._features) - for v_ in (v1, v2): - if v_ in (v_cpu, v_cpu_2): - self._check_vars_on_device(v_, device_cpu) + assert_fun(volumes1._features, volumes2._features) + for volumes_ in (volumes1, volumes2): + if volumes_ in (cpu_volumes, cpu_volumes2): + self._check_vars_on_device(volumes_, cpu_device) else: - self._check_vars_on_device(v_, device) + self._check_vars_on_device(volumes_, cuda_device) def _check_padded(self, x_pad, x_list, grid_sizes): """