From 633d66f1f03653fbc476774b04f3d9e6fc940d24 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 9 Jun 2021 15:48:56 -0700 Subject: [PATCH] Tidy uses of torch.device in Meshes Summary: Tidy uses of `torch.device` in `Meshes`: - Allow `str` or `torch.device` in `Meshes.to()` method - Consistently use `torch.device` for internal type - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28969461 fbshipit-source-id: 16d3c1f5458954bb11fdf0efea88542e94dccd7a --- pytorch3d/structures/meshes.py | 39 ++++++++++++++++++---------------- tests/test_meshes.py | 29 ++++++++++++++++++++----- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 38f75f5f..ac88cd05 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -4,6 +4,7 @@ from typing import List, Union import torch +from ..common.types import Device, make_device from . import utils as struct_utils @@ -250,7 +251,7 @@ class Meshes: Refer to comments above for descriptions of List and Padded representations. """ - self.device = None + self.device = torch.device("cpu") if textures is not None and not hasattr(textures, "sample_textures"): msg = "Expected textures to be an instance of type TexturesBase; got %r" raise ValueError(msg % type(textures)) @@ -339,7 +340,6 @@ class Meshes: f[f.gt(-1).all(1)].to(torch.int64) if len(f) > 0 else f for f in faces ] self._N = len(self._verts_list) - self.device = torch.device("cpu") self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) if self._N > 0: self.device = self._verts_list[0].device @@ -1222,7 +1222,7 @@ class Meshes: other.textures = self.textures.detach() return other - def to(self, device, copy: bool = False): + def to(self, device: Device, copy: bool = False): """ Match functionality of torch.Tensor.to() If copy = True or the self Tensor is on a different device, the @@ -1231,34 +1231,37 @@ class Meshes: then self is returned. Args: - device: Device id for the new tensor. + device: Device (as str or torch.device) for the new tensor. copy: Boolean indicator whether or not to clone self. Default False. Returns: Meshes object. """ - if not copy and self.device == device: + device_ = make_device(device) + if not copy and self.device == device_: return self other = self.clone() - if self.device != device: - other.device = device - if other._N > 0: - other._verts_list = [v.to(device) for v in other._verts_list] - other._faces_list = [f.to(device) for f in other._faces_list] - for k in self._INTERNAL_TENSORS: - v = getattr(self, k) - if torch.is_tensor(v): - setattr(other, k, v.to(device)) - if self.textures is not None: - other.textures = other.textures.to(device) + if self.device == device_: + return other + + other.device = device_ + if other._N > 0: + other._verts_list = [v.to(device_) for v in other._verts_list] + other._faces_list = [f.to(device_) for f in other._faces_list] + for k in self._INTERNAL_TENSORS: + v = getattr(self, k) + if torch.is_tensor(v): + setattr(other, k, v.to(device_)) + if self.textures is not None: + other.textures = other.textures.to(device_) return other def cpu(self): - return self.to(torch.device("cpu")) + return self.to("cpu") def cuda(self): - return self.to(torch.device("cuda")) + return self.to("cuda") def get_mesh_verts_faces(self, index: int): """ diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 5d2d5a6d..929bebcb 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -719,12 +719,31 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): mesh.extend(N=-1) def test_to(self): - mesh = init_mesh(5, 10, 100, device=torch.device("cuda:0")) - device = torch.device("cuda:1") + mesh = init_mesh(5, 10, 100) - new_mesh = mesh.to(device) - self.assertTrue(new_mesh.device == device) - self.assertTrue(mesh.device == torch.device("cuda:0")) + cpu_device = torch.device("cpu") + + converted_mesh = mesh.to("cpu") + self.assertEqual(cpu_device, converted_mesh.device) + self.assertEqual(cpu_device, mesh.device) + self.assertIs(mesh, converted_mesh) + + converted_mesh = mesh.to(cpu_device) + self.assertEqual(cpu_device, converted_mesh.device) + self.assertEqual(cpu_device, mesh.device) + self.assertIs(mesh, converted_mesh) + + cuda_device = torch.device("cuda") + + converted_mesh = mesh.to("cuda") + self.assertEqual(cuda_device, converted_mesh.device) + self.assertEqual(cpu_device, mesh.device) + self.assertIsNot(mesh, converted_mesh) + + converted_mesh = mesh.to(cuda_device) + self.assertEqual(cuda_device, converted_mesh.device) + self.assertEqual(cpu_device, mesh.device) + self.assertIsNot(mesh, converted_mesh) def test_split_mesh(self): mesh = init_mesh(5, 10, 100)