mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-15 17:05:58 +08:00
Fix type annotations for device type
Summary: Fix type annotations for device type Reviewed By: nikhilaravi Differential Revision: D28971179 fbshipit-source-id: 410b673c76dfd65ac51b2d144f17ed86a04a3058
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1f9661e150
commit
626bf3fe23
@@ -10,6 +10,8 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..common.types import Device, make_device
|
||||
|
||||
|
||||
class TensorAccessor(nn.Module):
|
||||
"""
|
||||
@@ -88,17 +90,19 @@ class TensorProperties(nn.Module):
|
||||
A mix-in class for storing tensors as properties with helper methods.
|
||||
"""
|
||||
|
||||
def __init__(self, dtype=torch.float32, device="cpu", **kwargs):
|
||||
def __init__(
|
||||
self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dtype: data type to set for the inputs
|
||||
device: str or torch.device
|
||||
device: Device (as str or torch.device)
|
||||
kwargs: any number of keyword arguments. Any arguments which are
|
||||
of type (float/int/tuple/tensor/array) are broadcasted and
|
||||
of type (float/int/list/tuple/tensor/array) are broadcasted and
|
||||
other keyword arguments are set as attributes.
|
||||
"""
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.device = make_device(device)
|
||||
self._N = 0
|
||||
if kwargs is not None:
|
||||
|
||||
@@ -108,7 +112,7 @@ class TensorProperties(nn.Module):
|
||||
for k, v in kwargs.items():
|
||||
if v is None or isinstance(v, (str, bool)):
|
||||
setattr(self, k, v)
|
||||
elif isinstance(v, BROADCAST_TYPES):
|
||||
elif isinstance(v, BROADCAST_TYPES): # pyre-fixme[6]
|
||||
args_to_broadcast[k] = v
|
||||
else:
|
||||
msg = "Arg %s with type %r is not broadcastable"
|
||||
@@ -152,17 +156,18 @@ class TensorProperties(nn.Module):
|
||||
msg = "Expected index of type int or slice; got %r"
|
||||
raise ValueError(msg % type(index))
|
||||
|
||||
def to(self, device: str = "cpu"):
|
||||
def to(self, device: Device = "cpu"):
|
||||
"""
|
||||
In place operation to move class properties which are tensors to a
|
||||
specified device. If self has a property "device", update this as well.
|
||||
"""
|
||||
device_ = make_device(device)
|
||||
for k in dir(self):
|
||||
v = getattr(self, k)
|
||||
if k == "device":
|
||||
setattr(self, k, device)
|
||||
if torch.is_tensor(v) and v.device != device:
|
||||
setattr(self, k, v.to(device))
|
||||
setattr(self, k, device_)
|
||||
if torch.is_tensor(v) and v.device != device_:
|
||||
setattr(self, k, v.to(device_))
|
||||
return self
|
||||
|
||||
def clone(self, other):
|
||||
@@ -257,28 +262,37 @@ class TensorProperties(nn.Module):
|
||||
return self
|
||||
|
||||
|
||||
def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor:
|
||||
def format_tensor(
|
||||
input, dtype: torch.dtype = torch.float32, device: Device = "cpu"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper function for converting a scalar value to a tensor.
|
||||
|
||||
Args:
|
||||
input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
|
||||
dtype: data type for the input
|
||||
device: torch device on which the tensor should be placed.
|
||||
device: Device (as str or torch.device) on which the tensor should be placed.
|
||||
|
||||
Returns:
|
||||
input_vec: torch tensor with optional added batch dimension.
|
||||
"""
|
||||
device_ = make_device(device)
|
||||
if not torch.is_tensor(input):
|
||||
input = torch.tensor(input, dtype=dtype, device=device)
|
||||
input = torch.tensor(input, dtype=dtype, device=device_)
|
||||
|
||||
if input.dim() == 0:
|
||||
input = input.view(1)
|
||||
if input.device != device:
|
||||
input = input.to(device=device)
|
||||
|
||||
if input.device == device_:
|
||||
return input
|
||||
|
||||
input = input.to(device=device)
|
||||
return input
|
||||
|
||||
|
||||
def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"):
|
||||
def convert_to_tensors_and_broadcast(
|
||||
*args, dtype: torch.dtype = torch.float32, device: Device = "cpu"
|
||||
):
|
||||
"""
|
||||
Helper function to handle parsing an arbitrary number of inputs (*args)
|
||||
which all need to have the same batch dimension.
|
||||
|
||||
Reference in New Issue
Block a user