From 0c02ae907edc2db9aee7d5bda1159814ce06ee56 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Tue, 13 Jul 2021 10:28:41 -0700 Subject: [PATCH] Adding utility methods to TensorProperties MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Context: in the code we are releasing with CO3D dataset, we use `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build. Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed. Reviewed By: bottler Differential Revision: D29659529 fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce --- pytorch3d/common/types.py | 10 ++++++++-- pytorch3d/renderer/utils.py | 8 +++++++- tests/test_meshes.py | 4 ++-- tests/test_rendering_utils.py | 12 +++++++++++- tests/test_shader.py | 2 +- tests/test_transforms.py | 4 ++-- 6 files changed, 31 insertions(+), 9 deletions(-) diff --git a/pytorch3d/common/types.py b/pytorch3d/common/types.py index ab15a184..da5e71a9 100644 --- a/pytorch3d/common/types.py +++ b/pytorch3d/common/types.py @@ -15,7 +15,8 @@ Device = Union[str, torch.device] def make_device(device: Device) -> torch.device: """ Makes an actual torch.device object from the device specified as - either a string or torch.device object. + either a string or torch.device object. If the device is `cuda` without + a specific index, the index of the current device is assigned. Args: device: Device (as str or torch.device) @@ -23,7 +24,12 @@ def make_device(device: Device) -> torch.device: Returns: A matching torch.device object """ - return torch.device(device) if isinstance(device, str) else device + device = torch.device(device) if isinstance(device, str) else device + if device.type == "cuda" and device.index is None: # pyre-ignore[16] + # If cuda but with no index, then the current cuda device is indicated. + # In that case, we fix to that device + device = torch.device(f"cuda:{torch.cuda.current_device()}") + return device def get_device(x, device: Optional[Device] = None) -> torch.device: diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index cf0d1187..d6e52e6e 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -8,7 +8,7 @@ import copy import inspect import warnings -from typing import Any, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -174,6 +174,12 @@ class TensorProperties(nn.Module): setattr(self, k, v.to(device_)) return self + def cpu(self) -> "TensorProperties": + return self.to("cpu") + + def cuda(self, device: Optional[int] = None) -> "TensorProperties": + return self.to(f"cuda:{device}" if device is not None else "cuda") + def clone(self, other) -> "TensorProperties": """ Update the tensor properties of other with the cloned properties of self. diff --git a/tests/test_meshes.py b/tests/test_meshes.py index f350a989..71acf1a6 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -709,9 +709,9 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertEqual(cpu_device, mesh.device) self.assertIs(mesh, converted_mesh) - cuda_device = torch.device("cuda") + cuda_device = torch.device("cuda:0") - converted_mesh = mesh.to("cuda") + converted_mesh = mesh.to("cuda:0") self.assertEqual(cuda_device, converted_mesh.device) self.assertEqual(cpu_device, mesh.device) self.assertIsNot(mesh, converted_mesh) diff --git a/tests/test_rendering_utils.py b/tests/test_rendering_utils.py index dc60b00a..6828ff41 100644 --- a/tests/test_rendering_utils.py +++ b/tests/test_rendering_utils.py @@ -39,7 +39,17 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase): example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0)) device = torch.device("cuda:0") new_example = example.to(device=device) - self.assertTrue(new_example.device == device) + self.assertEqual(new_example.device, device) + + example_cpu = example.cpu() + self.assertEqual(example_cpu.device, torch.device("cpu")) + + example_gpu = example.cuda() + self.assertEqual(example_gpu.device.type, "cuda") + self.assertIsNotNone(example_gpu.device.index) + + example_gpu1 = example.cuda(1) + self.assertEqual(example_gpu1.device, torch.device("cuda:1")) def test_clone(self): # Check clone method diff --git a/tests/test_shader.py b/tests/test_shader.py index 11ad66d5..abc2bd8d 100644 --- a/tests/test_shader.py +++ b/tests/test_shader.py @@ -22,7 +22,7 @@ from pytorch3d.structures.meshes import Meshes class TestShader(TestCaseMixin, unittest.TestCase): def test_to(self): cpu_device = torch.device("cpu") - cuda_device = torch.device("cuda") + cuda_device = torch.device("cuda:0") R, T = look_at_view_transform() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index d77a5d84..25dc21bb 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -50,9 +50,9 @@ class TestTransform(TestCaseMixin, unittest.TestCase): self.assertEqual(torch.float32, t.dtype) self.assertIsNot(t, cpu_t) - cuda_device = torch.device("cuda") + cuda_device = torch.device("cuda:0") - cuda_t = t.to("cuda") + cuda_t = t.to("cuda:0") self.assertEqual(cuda_device, cuda_t.device) self.assertEqual(cpu_device, t.device) self.assertEqual(torch.float32, cuda_t.dtype)