Adding utility methods to TensorProperties

Summary:
Context: in the code we are releasing with CO3D dataset, we use  `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build.

Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed.

Reviewed By: bottler

Differential Revision: D29659529

fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce
This commit is contained in:
Roman Shapovalov 2021-07-13 10:28:41 -07:00 committed by Facebook GitHub Bot
parent fa44a05567
commit 0c02ae907e
6 changed files with 31 additions and 9 deletions

View File

@ -15,7 +15,8 @@ Device = Union[str, torch.device]
def make_device(device: Device) -> torch.device: def make_device(device: Device) -> torch.device:
""" """
Makes an actual torch.device object from the device specified as Makes an actual torch.device object from the device specified as
either a string or torch.device object. either a string or torch.device object. If the device is `cuda` without
a specific index, the index of the current device is assigned.
Args: Args:
device: Device (as str or torch.device) device: Device (as str or torch.device)
@ -23,7 +24,12 @@ def make_device(device: Device) -> torch.device:
Returns: Returns:
A matching torch.device object A matching torch.device object
""" """
return torch.device(device) if isinstance(device, str) else device device = torch.device(device) if isinstance(device, str) else device
if device.type == "cuda" and device.index is None: # pyre-ignore[16]
# If cuda but with no index, then the current cuda device is indicated.
# In that case, we fix to that device
device = torch.device(f"cuda:{torch.cuda.current_device()}")
return device
def get_device(x, device: Optional[Device] = None) -> torch.device: def get_device(x, device: Optional[Device] = None) -> torch.device:

View File

@ -8,7 +8,7 @@
import copy import copy
import inspect import inspect
import warnings import warnings
from typing import Any, Union from typing import Any, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -174,6 +174,12 @@ class TensorProperties(nn.Module):
setattr(self, k, v.to(device_)) setattr(self, k, v.to(device_))
return self return self
def cpu(self) -> "TensorProperties":
return self.to("cpu")
def cuda(self, device: Optional[int] = None) -> "TensorProperties":
return self.to(f"cuda:{device}" if device is not None else "cuda")
def clone(self, other) -> "TensorProperties": def clone(self, other) -> "TensorProperties":
""" """
Update the tensor properties of other with the cloned properties of self. Update the tensor properties of other with the cloned properties of self.

View File

@ -709,9 +709,9 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertEqual(cpu_device, mesh.device) self.assertEqual(cpu_device, mesh.device)
self.assertIs(mesh, converted_mesh) self.assertIs(mesh, converted_mesh)
cuda_device = torch.device("cuda") cuda_device = torch.device("cuda:0")
converted_mesh = mesh.to("cuda") converted_mesh = mesh.to("cuda:0")
self.assertEqual(cuda_device, converted_mesh.device) self.assertEqual(cuda_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device) self.assertEqual(cpu_device, mesh.device)
self.assertIsNot(mesh, converted_mesh) self.assertIsNot(mesh, converted_mesh)

View File

@ -39,7 +39,17 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0)) example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0))
device = torch.device("cuda:0") device = torch.device("cuda:0")
new_example = example.to(device=device) new_example = example.to(device=device)
self.assertTrue(new_example.device == device) self.assertEqual(new_example.device, device)
example_cpu = example.cpu()
self.assertEqual(example_cpu.device, torch.device("cpu"))
example_gpu = example.cuda()
self.assertEqual(example_gpu.device.type, "cuda")
self.assertIsNotNone(example_gpu.device.index)
example_gpu1 = example.cuda(1)
self.assertEqual(example_gpu1.device, torch.device("cuda:1"))
def test_clone(self): def test_clone(self):
# Check clone method # Check clone method

View File

@ -22,7 +22,7 @@ from pytorch3d.structures.meshes import Meshes
class TestShader(TestCaseMixin, unittest.TestCase): class TestShader(TestCaseMixin, unittest.TestCase):
def test_to(self): def test_to(self):
cpu_device = torch.device("cpu") cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda") cuda_device = torch.device("cuda:0")
R, T = look_at_view_transform() R, T = look_at_view_transform()

View File

@ -50,9 +50,9 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
self.assertEqual(torch.float32, t.dtype) self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cpu_t) self.assertIsNot(t, cpu_t)
cuda_device = torch.device("cuda") cuda_device = torch.device("cuda:0")
cuda_t = t.to("cuda") cuda_t = t.to("cuda:0")
self.assertEqual(cuda_device, cuda_t.device) self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device) self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cuda_t.dtype) self.assertEqual(torch.float32, cuda_t.dtype)