Adding utility methods to TensorProperties

Summary:
Context: in the code we are releasing with CO3D dataset, we use  `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build.

Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed.

Reviewed By: bottler

Differential Revision: D29659529

fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce
This commit is contained in:
Roman Shapovalov
2021-07-13 10:28:41 -07:00
committed by Facebook GitHub Bot
parent fa44a05567
commit 0c02ae907e
6 changed files with 31 additions and 9 deletions

View File

@@ -15,7 +15,8 @@ Device = Union[str, torch.device]
def make_device(device: Device) -> torch.device:
"""
Makes an actual torch.device object from the device specified as
either a string or torch.device object.
either a string or torch.device object. If the device is `cuda` without
a specific index, the index of the current device is assigned.
Args:
device: Device (as str or torch.device)
@@ -23,7 +24,12 @@ def make_device(device: Device) -> torch.device:
Returns:
A matching torch.device object
"""
return torch.device(device) if isinstance(device, str) else device
device = torch.device(device) if isinstance(device, str) else device
if device.type == "cuda" and device.index is None: # pyre-ignore[16]
# If cuda but with no index, then the current cuda device is indicated.
# In that case, we fix to that device
device = torch.device(f"cuda:{torch.cuda.current_device()}")
return device
def get_device(x, device: Optional[Device] = None) -> torch.device:

View File

@@ -8,7 +8,7 @@
import copy
import inspect
import warnings
from typing import Any, Union
from typing import Any, Optional, Union
import numpy as np
import torch
@@ -174,6 +174,12 @@ class TensorProperties(nn.Module):
setattr(self, k, v.to(device_))
return self
def cpu(self) -> "TensorProperties":
return self.to("cpu")
def cuda(self, device: Optional[int] = None) -> "TensorProperties":
return self.to(f"cuda:{device}" if device is not None else "cuda")
def clone(self, other) -> "TensorProperties":
"""
Update the tensor properties of other with the cloned properties of self.