Tidy uses of torch.device in Meshes

Summary:
Tidy uses of `torch.device` in `Meshes`:
- Allow `str` or `torch.device` in `Meshes.to()` method
- Consistently use `torch.device` for internal type
- Fix comparison of devices

Reviewed By: nikhilaravi

Differential Revision: D28969461

fbshipit-source-id: 16d3c1f5458954bb11fdf0efea88542e94dccd7a
This commit is contained in:
Patrick Labatut 2021-06-09 15:48:56 -07:00 committed by Facebook GitHub Bot
parent 13a0110b69
commit 633d66f1f0
2 changed files with 45 additions and 23 deletions

View File

@ -4,6 +4,7 @@ from typing import List, Union
import torch import torch
from ..common.types import Device, make_device
from . import utils as struct_utils from . import utils as struct_utils
@ -250,7 +251,7 @@ class Meshes:
Refer to comments above for descriptions of List and Padded representations. Refer to comments above for descriptions of List and Padded representations.
""" """
self.device = None self.device = torch.device("cpu")
if textures is not None and not hasattr(textures, "sample_textures"): if textures is not None and not hasattr(textures, "sample_textures"):
msg = "Expected textures to be an instance of type TexturesBase; got %r" msg = "Expected textures to be an instance of type TexturesBase; got %r"
raise ValueError(msg % type(textures)) raise ValueError(msg % type(textures))
@ -339,7 +340,6 @@ class Meshes:
f[f.gt(-1).all(1)].to(torch.int64) if len(f) > 0 else f for f in faces f[f.gt(-1).all(1)].to(torch.int64) if len(f) > 0 else f for f in faces
] ]
self._N = len(self._verts_list) self._N = len(self._verts_list)
self.device = torch.device("cpu")
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0: if self._N > 0:
self.device = self._verts_list[0].device self.device = self._verts_list[0].device
@ -1222,7 +1222,7 @@ class Meshes:
other.textures = self.textures.detach() other.textures = self.textures.detach()
return other return other
def to(self, device, copy: bool = False): def to(self, device: Device, copy: bool = False):
""" """
Match functionality of torch.Tensor.to() Match functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the If copy = True or the self Tensor is on a different device, the
@ -1231,34 +1231,37 @@ class Meshes:
then self is returned. then self is returned.
Args: Args:
device: Device id for the new tensor. device: Device (as str or torch.device) for the new tensor.
copy: Boolean indicator whether or not to clone self. Default False. copy: Boolean indicator whether or not to clone self. Default False.
Returns: Returns:
Meshes object. Meshes object.
""" """
if not copy and self.device == device: device_ = make_device(device)
if not copy and self.device == device_:
return self return self
other = self.clone() other = self.clone()
if self.device != device: if self.device == device_:
other.device = device return other
if other._N > 0:
other._verts_list = [v.to(device) for v in other._verts_list] other.device = device_
other._faces_list = [f.to(device) for f in other._faces_list] if other._N > 0:
for k in self._INTERNAL_TENSORS: other._verts_list = [v.to(device_) for v in other._verts_list]
v = getattr(self, k) other._faces_list = [f.to(device_) for f in other._faces_list]
if torch.is_tensor(v): for k in self._INTERNAL_TENSORS:
setattr(other, k, v.to(device)) v = getattr(self, k)
if self.textures is not None: if torch.is_tensor(v):
other.textures = other.textures.to(device) setattr(other, k, v.to(device_))
if self.textures is not None:
other.textures = other.textures.to(device_)
return other return other
def cpu(self): def cpu(self):
return self.to(torch.device("cpu")) return self.to("cpu")
def cuda(self): def cuda(self):
return self.to(torch.device("cuda")) return self.to("cuda")
def get_mesh_verts_faces(self, index: int): def get_mesh_verts_faces(self, index: int):
""" """

View File

@ -719,12 +719,31 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
mesh.extend(N=-1) mesh.extend(N=-1)
def test_to(self): def test_to(self):
mesh = init_mesh(5, 10, 100, device=torch.device("cuda:0")) mesh = init_mesh(5, 10, 100)
device = torch.device("cuda:1")
new_mesh = mesh.to(device) cpu_device = torch.device("cpu")
self.assertTrue(new_mesh.device == device)
self.assertTrue(mesh.device == torch.device("cuda:0")) converted_mesh = mesh.to("cpu")
self.assertEqual(cpu_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIs(mesh, converted_mesh)
converted_mesh = mesh.to(cpu_device)
self.assertEqual(cpu_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIs(mesh, converted_mesh)
cuda_device = torch.device("cuda")
converted_mesh = mesh.to("cuda")
self.assertEqual(cuda_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIsNot(mesh, converted_mesh)
converted_mesh = mesh.to(cuda_device)
self.assertEqual(cuda_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device)
self.assertIsNot(mesh, converted_mesh)
def test_split_mesh(self): def test_split_mesh(self):
mesh = init_mesh(5, 10, 100) mesh = init_mesh(5, 10, 100)