From 1db40ac56629e11d16c6b6beae56f5ee006343cc Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 9 Jun 2021 15:48:56 -0700 Subject: [PATCH] Tidy uses of torch.device in Pointclouds Summary: Tidy uses of `torch.device` in `Pointclouds`: - Allow `str` or `torch.device` in `Pointclouds.to()` method - Consistently use `torch.device` for internal type - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28970221 fbshipit-source-id: 3ca7104d4c0d9b20b0cff4f00e3ce931c5f1528a --- pytorch3d/structures/pointclouds.py | 43 ++++++++++++++++------------- tests/test_pointclouds.py | 27 ++++++++++++++++++ 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 52de6192..9a43717b 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -2,6 +2,7 @@ import torch +from ..common.types import Device, make_device from . import utils as struct_utils @@ -130,7 +131,7 @@ class Pointclouds: Refer to comments above for descriptions of List and Padded representations. """ - self.device = None + self.device = torch.device("cpu") # Indicates whether the clouds in the list/batch have the same number # of points. @@ -175,7 +176,6 @@ class Pointclouds: if isinstance(points, list): self._points_list = points self._N = len(self._points_list) - self.device = torch.device("cpu") self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) self._num_points_per_cloud = [] @@ -700,7 +700,7 @@ class Pointclouds: setattr(other, k, v.detach()) return other - def to(self, device, copy: bool = False): + def to(self, device: Device, copy: bool = False): """ Match functionality of torch.Tensor.to() If copy = True or the self Tensor is on a different device, the @@ -709,34 +709,39 @@ class Pointclouds: then self is returned. Args: - device: Device id for the new tensor. + device: Device (as str or torch.device) for the new tensor. copy: Boolean indicator whether or not to clone self. Default False. Returns: Pointclouds object. """ - if not copy and self.device == device: + device_ = make_device(device) + + if not copy and self.device == device_: return self + other = self.clone() - if self.device != device: - other.device = device - if other._N > 0: - other._points_list = [v.to(device) for v in other.points_list()] - if other._normals_list is not None: - other._normals_list = [n.to(device) for n in other.normals_list()] - if other._features_list is not None: - other._features_list = [f.to(device) for f in other.features_list()] - for k in self._INTERNAL_TENSORS: - v = getattr(self, k) - if torch.is_tensor(v): - setattr(other, k, v.to(device)) + if self.device == device_: + return other + + other.device = device_ + if other._N > 0: + other._points_list = [v.to(device_) for v in other.points_list()] + if other._normals_list is not None: + other._normals_list = [n.to(device_) for n in other.normals_list()] + if other._features_list is not None: + other._features_list = [f.to(device_) for f in other.features_list()] + for k in self._INTERNAL_TENSORS: + v = getattr(self, k) + if torch.is_tensor(v): + setattr(other, k, v.to(device_)) return other def cpu(self): - return self.to(torch.device("cpu")) + return self.to("cpu") def cuda(self): - return self.to(torch.device("cuda")) + return self.to("cuda") def get_cloud(self, index: int): """ diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index 9fdaac8a..75b28266 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -633,6 +633,33 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError): clouds.extend(N=-1) + def test_to(self): + cloud = self.init_cloud(5, 100, 10) # Using device "cuda:0" + + cuda_device = torch.device("cuda:0") + + converted_cloud = cloud.to("cuda:0") + self.assertEqual(cuda_device, converted_cloud.device) + self.assertEqual(cuda_device, cloud.device) + self.assertIs(cloud, converted_cloud) + + converted_cloud = cloud.to(cuda_device) + self.assertEqual(cuda_device, converted_cloud.device) + self.assertEqual(cuda_device, cloud.device) + self.assertIs(cloud, converted_cloud) + + cpu_device = torch.device("cpu") + + converted_cloud = cloud.to("cpu") + self.assertEqual(cpu_device, converted_cloud.device) + self.assertEqual(cuda_device, cloud.device) + self.assertIsNot(cloud, converted_cloud) + + converted_cloud = cloud.to(cpu_device) + self.assertEqual(cpu_device, converted_cloud.device) + self.assertEqual(cuda_device, cloud.device) + self.assertIsNot(cloud, converted_cloud) + def test_to_list(self): cloud = self.init_cloud(5, 100, 10) device = torch.device("cuda:1")