Tidy uses of torch.device in Pointclouds

Summary:
Tidy uses of `torch.device` in `Pointclouds`:
- Allow `str` or `torch.device` in `Pointclouds.to()` method
- Consistently use `torch.device` for internal type
- Fix comparison of devices

Reviewed By: nikhilaravi

Differential Revision: D28970221

fbshipit-source-id: 3ca7104d4c0d9b20b0cff4f00e3ce931c5f1528a
This commit is contained in:
Patrick Labatut 2021-06-09 15:48:56 -07:00 committed by Facebook GitHub Bot
parent 633d66f1f0
commit 1db40ac566
2 changed files with 51 additions and 19 deletions

View File

@ -2,6 +2,7 @@
import torch import torch
from ..common.types import Device, make_device
from . import utils as struct_utils from . import utils as struct_utils
@ -130,7 +131,7 @@ class Pointclouds:
Refer to comments above for descriptions of List and Padded Refer to comments above for descriptions of List and Padded
representations. representations.
""" """
self.device = None self.device = torch.device("cpu")
# Indicates whether the clouds in the list/batch have the same number # Indicates whether the clouds in the list/batch have the same number
# of points. # of points.
@ -175,7 +176,6 @@ class Pointclouds:
if isinstance(points, list): if isinstance(points, list):
self._points_list = points self._points_list = points
self._N = len(self._points_list) self._N = len(self._points_list)
self.device = torch.device("cpu")
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
self._num_points_per_cloud = [] self._num_points_per_cloud = []
@ -700,7 +700,7 @@ class Pointclouds:
setattr(other, k, v.detach()) setattr(other, k, v.detach())
return other return other
def to(self, device, copy: bool = False): def to(self, device: Device, copy: bool = False):
""" """
Match functionality of torch.Tensor.to() Match functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the If copy = True or the self Tensor is on a different device, the
@ -709,34 +709,39 @@ class Pointclouds:
then self is returned. then self is returned.
Args: Args:
device: Device id for the new tensor. device: Device (as str or torch.device) for the new tensor.
copy: Boolean indicator whether or not to clone self. Default False. copy: Boolean indicator whether or not to clone self. Default False.
Returns: Returns:
Pointclouds object. Pointclouds object.
""" """
if not copy and self.device == device: device_ = make_device(device)
if not copy and self.device == device_:
return self return self
other = self.clone() other = self.clone()
if self.device != device: if self.device == device_:
other.device = device return other
if other._N > 0:
other._points_list = [v.to(device) for v in other.points_list()] other.device = device_
if other._normals_list is not None: if other._N > 0:
other._normals_list = [n.to(device) for n in other.normals_list()] other._points_list = [v.to(device_) for v in other.points_list()]
if other._features_list is not None: if other._normals_list is not None:
other._features_list = [f.to(device) for f in other.features_list()] other._normals_list = [n.to(device_) for n in other.normals_list()]
for k in self._INTERNAL_TENSORS: if other._features_list is not None:
v = getattr(self, k) other._features_list = [f.to(device_) for f in other.features_list()]
if torch.is_tensor(v): for k in self._INTERNAL_TENSORS:
setattr(other, k, v.to(device)) v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.to(device_))
return other return other
def cpu(self): def cpu(self):
return self.to(torch.device("cpu")) return self.to("cpu")
def cuda(self): def cuda(self):
return self.to(torch.device("cuda")) return self.to("cuda")
def get_cloud(self, index: int): def get_cloud(self, index: int):
""" """

View File

@ -633,6 +633,33 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
clouds.extend(N=-1) clouds.extend(N=-1)
def test_to(self):
cloud = self.init_cloud(5, 100, 10) # Using device "cuda:0"
cuda_device = torch.device("cuda:0")
converted_cloud = cloud.to("cuda:0")
self.assertEqual(cuda_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIs(cloud, converted_cloud)
converted_cloud = cloud.to(cuda_device)
self.assertEqual(cuda_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIs(cloud, converted_cloud)
cpu_device = torch.device("cpu")
converted_cloud = cloud.to("cpu")
self.assertEqual(cpu_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIsNot(cloud, converted_cloud)
converted_cloud = cloud.to(cpu_device)
self.assertEqual(cpu_device, converted_cloud.device)
self.assertEqual(cuda_device, cloud.device)
self.assertIsNot(cloud, converted_cloud)
def test_to_list(self): def test_to_list(self):
cloud = self.init_cloud(5, 100, 10) cloud = self.init_cloud(5, 100, 10)
device = torch.device("cuda:1") device = torch.device("cuda:1")