Fix type annotations for device type

Summary: Fix type annotations for device type

Reviewed By: nikhilaravi

Differential Revision: D28971179

fbshipit-source-id: 410b673c76dfd65ac51b2d144f17ed86a04a3058
This commit is contained in:
Patrick Labatut
2021-06-09 15:48:56 -07:00
committed by Facebook GitHub Bot
parent 1f9661e150
commit 626bf3fe23
12 changed files with 110 additions and 58 deletions

View File

@@ -10,6 +10,8 @@ import numpy as np
import torch
import torch.nn as nn
from ..common.types import Device, make_device
class TensorAccessor(nn.Module):
"""
@@ -88,17 +90,19 @@ class TensorProperties(nn.Module):
A mix-in class for storing tensors as properties with helper methods.
"""
def __init__(self, dtype=torch.float32, device="cpu", **kwargs):
def __init__(
self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
):
"""
Args:
dtype: data type to set for the inputs
device: str or torch.device
device: Device (as str or torch.device)
kwargs: any number of keyword arguments. Any arguments which are
of type (float/int/tuple/tensor/array) are broadcasted and
of type (float/int/list/tuple/tensor/array) are broadcasted and
other keyword arguments are set as attributes.
"""
super().__init__()
self.device = device
self.device = make_device(device)
self._N = 0
if kwargs is not None:
@@ -108,7 +112,7 @@ class TensorProperties(nn.Module):
for k, v in kwargs.items():
if v is None or isinstance(v, (str, bool)):
setattr(self, k, v)
elif isinstance(v, BROADCAST_TYPES):
elif isinstance(v, BROADCAST_TYPES): # pyre-fixme[6]
args_to_broadcast[k] = v
else:
msg = "Arg %s with type %r is not broadcastable"
@@ -152,17 +156,18 @@ class TensorProperties(nn.Module):
msg = "Expected index of type int or slice; got %r"
raise ValueError(msg % type(index))
def to(self, device: str = "cpu"):
def to(self, device: Device = "cpu"):
"""
In place operation to move class properties which are tensors to a
specified device. If self has a property "device", update this as well.
"""
device_ = make_device(device)
for k in dir(self):
v = getattr(self, k)
if k == "device":
setattr(self, k, device)
if torch.is_tensor(v) and v.device != device:
setattr(self, k, v.to(device))
setattr(self, k, device_)
if torch.is_tensor(v) and v.device != device_:
setattr(self, k, v.to(device_))
return self
def clone(self, other):
@@ -257,28 +262,37 @@ class TensorProperties(nn.Module):
return self
def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor:
def format_tensor(
input, dtype: torch.dtype = torch.float32, device: Device = "cpu"
) -> torch.Tensor:
"""
Helper function for converting a scalar value to a tensor.
Args:
input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
dtype: data type for the input
device: torch device on which the tensor should be placed.
device: Device (as str or torch.device) on which the tensor should be placed.
Returns:
input_vec: torch tensor with optional added batch dimension.
"""
device_ = make_device(device)
if not torch.is_tensor(input):
input = torch.tensor(input, dtype=dtype, device=device)
input = torch.tensor(input, dtype=dtype, device=device_)
if input.dim() == 0:
input = input.view(1)
if input.device != device:
input = input.to(device=device)
if input.device == device_:
return input
input = input.to(device=device)
return input
def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"):
def convert_to_tensors_and_broadcast(
*args, dtype: torch.dtype = torch.float32, device: Device = "cpu"
):
"""
Helper function to handle parsing an arbitrary number of inputs (*args)
which all need to have the same batch dimension.