Mirror of https://github.com/facebookresearch/pytorch3d.git
Fix type annotations for device type
Summary: Fix type annotations for device type

Reviewed By: nikhilaravi

Differential Revision: D28971179

fbshipit-source-id: 410b673c76dfd65ac51b2d144f17ed86a04a3058
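Context for reviewers: every hunk below swaps a plain `device="cpu"` (or `device: str`) parameter for the `Device` alias that the diff imports from `pytorch3d.common.types`. Roughly, and only as a paraphrased sketch of that module rather than its exact contents, the alias and its companion helper look like this:

import torch
from typing import Union

# A device argument may be given either as a string ("cpu", "cuda:0", ...)
# or as an already-constructed torch.device.
Device = Union[str, torch.device]


def make_device(device: Device) -> torch.device:
    # Normalize whichever form the caller passed into a torch.device.
    # (The library version also resolves a bare "cuda" to an explicit index.)
    return torch.device(device) if isinstance(device, str) else device

Annotating with `Device` instead of `str` documents that both forms were always accepted at runtime; `make_device` is what the renderer utilities at the bottom of this diff use to normalize the stored value.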
parent 1f9661e150
commit 626bf3fe23
@@ -9,6 +9,7 @@ from typing import Dict, List, Optional
 import numpy as np
 import torch
 from PIL import Image
+from pytorch3d.common.types import Device
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
 from pytorch3d.renderer import HardPhongShader
 from tabulate import tabulate
@@ -371,7 +372,7 @@ class R2N2(ShapeNetBase):
         idxs: Optional[List[int]] = None,
         view_idxs: Optional[List[int]] = None,
         shader_type=HardPhongShader,
-        device="cpu",
+        device: Device = "cpu",
         **kwargs
     ) -> torch.Tensor:
         """
@@ -394,7 +395,7 @@ class R2N2(ShapeNetBase):
             idxs: List[int] of indices of models to be rendered in the dataset.
             shader_type: Shader to use for rendering. Examples include HardPhongShader
                 (default), SoftPhongShader etc or any other type of valid Shader class.
-            device: torch.device on which the tensors should be located.
+            device: Device (as str or torch.device) on which the tensors should be located.
            **kwargs: Accepts any of the kwargs that the renderer supports and any of the
                args that BlenderCamera supports.

@@ -4,6 +4,7 @@ from typing import Dict, List

 import numpy as np
 import torch
+from pytorch3d.common.types import Device
 from pytorch3d.datasets.utils import collate_batched_meshes
 from pytorch3d.ops import cubify
 from pytorch3d.renderer import (
@@ -431,13 +432,13 @@ class BlenderCamera(CamerasBase):
     (which uses Blender for rendering the views for each model).
     """

-    def __init__(self, R=r, T=t, K=k, device="cpu"):
+    def __init__(self, R=r, T=t, K=k, device: Device = "cpu"):
         """
         Args:
             R: Rotation matrix of shape (N, 3, 3).
             T: Translation matrix of shape (N, 3).
             K: Intrinsic matrix of shape (N, 4, 4).
-            device: torch.device or str.
+            device: Device (as str or torch.device).
         """
         # The initializer formats all inputs to torch tensors and broadcasts
         # all the inputs to have the same batch dimension where necessary.
@@ -450,7 +451,7 @@ class BlenderCamera(CamerasBase):


 def render_cubified_voxels(
-    voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs
+    voxels: torch.Tensor, shader_type=HardPhongShader, device: Device = "cpu", **kwargs
 ):
     """
     Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh.
@@ -461,7 +462,7 @@ def render_cubified_voxels(
         shader_type: shader_type: shader_type: Shader to use for rendering. Examples
             include HardPhongShader (default), SoftPhongShader etc or any other type
             of valid Shader class.
-        device: torch.device on which the tensors should be located.
+        device: Device (as str or torch.device) on which the tensors should be located.
         **kwargs: Accepts any of the kwargs that the renderer supports.
     Returns:
         Batch of rendered images of shape (N, H, W, 3).
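A quick illustration of what the annotation means for callers of these dataset helpers; nothing changes at runtime, since both spellings were already valid (the BlenderCamera import path is assumed from how the class is used elsewhere in the repo):

import torch
from pytorch3d.datasets import BlenderCamera  # assumed export location

# Both forms satisfy the new `Device` annotation.
cam_a = BlenderCamera(device="cpu")
cam_b = BlenderCamera(device=torch.device("cpu"))
# After this patch TensorProperties normalizes both to a torch.device,
# so the two cameras report the same device.
assert cam_a.device == cam_b.device == torch.device("cpu")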
@@ -4,6 +4,7 @@ import warnings
 from typing import Dict, List, Optional, Tuple

 import torch
+from pytorch3d.common.types import Device
 from pytorch3d.io import load_obj
 from pytorch3d.renderer import (
     FoVPerspectiveCameras,
@@ -105,7 +106,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
         sample_nums: Optional[List[int]] = None,
         idxs: Optional[List[int]] = None,
         shader_type=HardPhongShader,
-        device="cpu",
+        device: Device = "cpu",
         **kwargs
     ) -> torch.Tensor:
         """
@@ -129,7 +130,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
             shader_type: Select shading. Valid options include HardPhongShader (default),
                 SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
                 SoftSilhouetteShader.
-            device: torch.device on which the tensors should be located.
+            device: Device (as str or torch.device) on which the tensors should be located.
             **kwargs: Accepts any of the kwargs that the renderer supports.

         Returns:
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from iopath.common.file_io import PathManager
+from pytorch3d.common.types import Device
 from pytorch3d.io.utils import _open_file, _read_image


@@ -393,7 +394,7 @@ TextureImages = Dict[str, torch.Tensor]


 def _parse_mtl(
-    f, path_manager: PathManager, device="cpu"
+    f, path_manager: PathManager, device: Device = "cpu"
 ) -> Tuple[MaterialProperties, TextureFiles]:
     material_properties = {}
     texture_files = {}
@@ -474,7 +475,7 @@ def load_mtl(
     *,
     material_names: List[str],
     data_dir: str,
-    device="cpu",
+    device: Device = "cpu",
     path_manager: PathManager,
 ) -> Tuple[MaterialProperties, TextureImages]:
     """
@@ -485,6 +486,7 @@ def load_mtl(
         f: a file-like object of the material information.
         material_names: a list of the material names found in the .obj file.
         data_dir: the directory where the material texture files are located.
+        device: Device (as str or torch.tensor) on which to return the new tensors.
         path_manager: PathManager for interpreting both f and material_names.

     Returns:
@@ -11,6 +11,7 @@ from typing import List, Optional, Union
 import numpy as np
 import torch
 from iopath.common.file_io import PathManager
+from pytorch3d.common.types import Device
 from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
 from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
 from pytorch3d.renderer import TexturesAtlas, TexturesUV
@@ -71,7 +72,7 @@ def load_obj(
     create_texture_atlas: bool = False,
     texture_atlas_size: int = 4,
     texture_wrap: Optional[str] = "repeat",
-    device="cpu",
+    device: Device = "cpu",
     path_manager: Optional[PathManager] = None,
 ):
     """
@@ -143,7 +144,7 @@ def load_obj(
             is ignored and a repeating pattern is formed.
             If `texture_mode="clamp"` the values are clamped to the range [0, 1].
             If None, then there is no transformation of the texture values.
-        device: string or torch.device on which to return the new tensors.
+        device: Device (as str or torch.device) on which to return the new tensors.
         path_manager: optionally a PathManager object to interpret paths.

     Returns:
@@ -226,7 +227,7 @@ def load_obj(

 def load_objs_as_meshes(
     files: list,
-    device=None,
+    device: Optional[Device] = None,
     load_textures: bool = True,
     create_texture_atlas: bool = False,
     texture_atlas_size: int = 4,
@@ -293,7 +294,7 @@ class MeshObjFormat(MeshFormatInterpreter):
         self,
         path: Union[str, Path],
         include_textures: bool,
-        device,
+        device: Device,
         path_manager: PathManager,
         create_texture_atlas: bool = False,
         texture_atlas_size: int = 4,
@@ -497,7 +498,7 @@ def _load_materials(
     *,
     data_dir: str,
     load_textures: bool,
-    device,
+    device: Device,
     path_manager: PathManager,
 ):
     """
@@ -508,7 +509,7 @@ def _load_materials(
         f: a file-like object of the material information.
         data_dir: the directory where the material texture files are located.
         load_textures: whether textures should be loaded.
-        device: string or torch.device on which to return the new tensors.
+        device: Device (as str or torch.device) on which to return the new tensors.
         path_manager: PathManager object to interpret paths.

     Returns:
@@ -546,7 +547,7 @@ def _load_obj(
     texture_atlas_size: int = 4,
     texture_wrap: Optional[str] = "repeat",
     path_manager: PathManager,
-    device="cpu",
+    device: Device = "cpu",
 ):
     """
     Load a mesh from a file-like object. See load_obj function more details.
@@ -8,6 +8,7 @@ from pathlib import Path
 from typing import Deque, Optional, Union

 from iopath.common.file_io import PathManager
+from pytorch3d.common.types import Device
 from pytorch3d.structures import Meshes, Pointclouds

 from .obj_io import MeshObjFormat
@@ -108,7 +109,7 @@ class IO:
         self,
         path: Union[str, Path],
         include_textures: bool = True,
-        device="cpu",
+        device: Device = "cpu",
         **kwargs,
     ) -> Meshes:
         """
@@ -168,14 +169,14 @@ class IO:
         raise ValueError(f"No mesh interpreter found to write to {path}.")

     def load_pointcloud(
-        self, path: Union[str, Path], device="cpu", **kwargs
+        self, path: Union[str, Path], device: Device = "cpu", **kwargs
     ) -> Pointclouds:
         """
         Attempt to load a point cloud from the given file, using a registered format.

         Args:
             path: file to read
-            device: torch.device on which to load the data.
+            device: Device (as str or torch.device) on which to load the data.

         Returns:
             new Pointclouds object containing one mesh.
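The same applies to the pluggable IO entry points: `load_mesh` and `load_pointcloud` now advertise that they take either a device string or a torch.device. A hypothetical call (the .obj path is made up for illustration):

import torch
from pytorch3d.io import IO

io = IO()
# Either spelling type-checks against the `Device` annotation.
mesh_cpu = io.load_mesh("assets/chair.obj", device="cpu")
mesh_gpu = io.load_mesh("assets/chair.obj", device=torch.device("cuda:0"))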
@@ -10,6 +10,8 @@ import torch
 from iopath.common.file_io import PathManager
 from PIL import Image

+from ..common.types import Device
+

 @contextlib.contextmanager
 def nullcontext(x):
@@ -31,7 +33,7 @@ def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:


 def _make_tensor(
-    data, cols: int, dtype: torch.dtype, device: str = "cpu"
+    data, cols: int, dtype: torch.dtype, device: Device = "cpu"
 ) -> torch.Tensor:
     """
     Return a 2D tensor with the specified cols and dtype filled with data,
@@ -7,6 +7,7 @@ from typing import Optional, Sequence, Tuple
 import numpy as np
 import torch
 import torch.nn.functional as F
+from pytorch3d.common.types import Device
 from pytorch3d.transforms import Rotate, Transform3d, Translate

 from .utils import TensorProperties, convert_to_tensors_and_broadcast
@@ -290,7 +291,7 @@ def OpenGLPerspectiveCameras(
     degrees: bool = True,
     R=_R,
     T=_T,
-    device="cpu",
+    device: Device = "cpu",
 ):
     """
     OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
@@ -358,7 +359,7 @@ class FoVPerspectiveCameras(CamerasBase):
         R=_R,
         T=_T,
         K=None,
-        device="cpu",
+        device: Device = "cpu",
     ):
         """

@@ -373,7 +374,7 @@ class FoVPerspectiveCameras(CamerasBase):
             T: Translation matrix of shape (N, 3)
             K: (optional) A calibration matrix of shape (N, 4, 4)
                 If provided, don't need znear, zfar, fov, aspect_ratio, degrees
-            device: torch.device or string
+            device: Device (as str or torch.device)
         """
         # The initializer formats all inputs to torch tensors and broadcasts
         # all the inputs to have the same batch dimension where necessary.
@@ -4,6 +4,7 @@
 import torch
 import torch.nn.functional as F

+from ..common.types import Device
 from .utils import TensorProperties, convert_to_tensors_and_broadcast


@@ -158,7 +159,7 @@ class DirectionalLights(TensorProperties):
         diffuse_color=((0.3, 0.3, 0.3),),
         specular_color=((0.2, 0.2, 0.2),),
         direction=((0, 1, 0),),
-        device="cpu",
+        device: Device = "cpu",
     ):
         """
         Args:
@@ -166,7 +167,7 @@ class DirectionalLights(TensorProperties):
             diffuse_color: RGB color of the diffuse component.
             specular_color: RGB color of the specular component.
             direction: (x, y, z) direction vector of the light.
-            device: torch.device on which the tensors should be located
+            device: Device (as str or torch.device) on which the tensors should be located

         The inputs can each be
             - 3 element tuple/list or list of lists
@@ -219,7 +220,7 @@ class PointLights(TensorProperties):
         diffuse_color=((0.3, 0.3, 0.3),),
         specular_color=((0.2, 0.2, 0.2),),
         location=((0, 1, 0),),
-        device="cpu",
+        device: Device = "cpu",
     ):
         """
         Args:
@@ -227,7 +228,7 @@ class PointLights(TensorProperties):
             diffuse_color: RGB color of the diffuse component
             specular_color: RGB color of the specular component
             location: xyz position of the light.
-            device: torch.device on which the tensors should be located
+            device: Device (as str or torch.device) on which the tensors should be located

         The inputs can each be
             - 3 element tuple/list or list of lists
@@ -275,14 +276,14 @@ class AmbientLights(TensorProperties):
     not used in rendering.
     """

-    def __init__(self, *, ambient_color=None, device="cpu"):
+    def __init__(self, *, ambient_color=None, device: Device = "cpu"):
         """
         If ambient_color is provided, it should be a sequence of
         triples of floats.

         Args:
             ambient_color: RGB color
-            device: torch.device on which the tensors should be located
+            device: Device (as str or torch.device) on which the tensors should be located

         The ambient_color if provided, should be
             - 3 element tuple/list or list of lists
@@ -3,6 +3,7 @@

 import torch

+from ..common.types import Device
 from .utils import TensorProperties


@@ -18,7 +19,7 @@ class Materials(TensorProperties):
         diffuse_color=((1, 1, 1),),
         specular_color=((1, 1, 1),),
         shininess=64,
-        device="cpu",
+        device: Device = "cpu",
     ):
         """
         Args:
@@ -29,7 +30,7 @@ class Materials(TensorProperties):
                 the focus of the specular highlight with a high value
                 resulting in a concentrated highlight. Shininess values
                 can range from 0-1000.
-            device: torch.device or string
+            device: Device (as str or torch.device) on which the tensors should be located

         ambient_color, diffuse_color and specular_color can be of shape
         (1, 3) or (N, 3). shininess can be of shape (1) or (N).
@@ -5,6 +5,7 @@ import warnings
 import torch
 import torch.nn as nn

+from ...common.types import Device
 from ..blending import (
     BlendParams,
     hard_rgb_blend,
@@ -40,7 +41,12 @@ class HardPhongShader(nn.Module):
     """

     def __init__(
-        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+        self,
+        device: Device = "cpu",
+        cameras=None,
+        lights=None,
+        materials=None,
+        blend_params=None,
     ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
@@ -50,7 +56,7 @@ class HardPhongShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

-    def to(self, device):
+    def to(self, device: Device):
         # Manually move to device modules which are not subclasses of nn.Module
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
@@ -95,7 +101,12 @@ class SoftPhongShader(nn.Module):
     """

     def __init__(
-        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+        self,
+        device: Device = "cpu",
+        cameras=None,
+        lights=None,
+        materials=None,
+        blend_params=None,
     ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
@@ -105,7 +116,7 @@ class SoftPhongShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

-    def to(self, device):
+    def to(self, device: Device):
         # Manually move to device modules which are not subclasses of nn.Module
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
@@ -155,7 +166,12 @@ class HardGouraudShader(nn.Module):
     """

     def __init__(
-        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+        self,
+        device: Device = "cpu",
+        cameras=None,
+        lights=None,
+        materials=None,
+        blend_params=None,
     ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
@@ -165,7 +181,7 @@ class HardGouraudShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

-    def to(self, device):
+    def to(self, device: Device):
         # Manually move to device modules which are not subclasses of nn.Module
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
@@ -214,7 +230,12 @@ class SoftGouraudShader(nn.Module):
     """

     def __init__(
-        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+        self,
+        device: Device = "cpu",
+        cameras=None,
+        lights=None,
+        materials=None,
+        blend_params=None,
     ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
@@ -224,7 +245,7 @@ class SoftGouraudShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

-    def to(self, device):
+    def to(self, device: Device):
         # Manually move to device modules which are not subclasses of nn.Module
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
@@ -255,7 +276,7 @@ class SoftGouraudShader(nn.Module):


 def TexturedSoftPhongShader(
-    device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+    device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
 ):
     """
     TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
@@ -290,7 +311,12 @@ class HardFlatShader(nn.Module):
     """

     def __init__(
-        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+        self,
+        device: Device = "cpu",
+        cameras=None,
+        lights=None,
+        materials=None,
+        blend_params=None,
     ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
@@ -300,7 +326,7 @@ class HardFlatShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

-    def to(self, device):
+    def to(self, device: Device):
         # Manually move to device modules which are not subclasses of nn.Module
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
@@ -10,6 +10,8 @@ import numpy as np
 import torch
 import torch.nn as nn

+from ..common.types import Device, make_device
+

 class TensorAccessor(nn.Module):
     """
@@ -88,17 +90,19 @@ class TensorProperties(nn.Module):
     A mix-in class for storing tensors as properties with helper methods.
     """

-    def __init__(self, dtype=torch.float32, device="cpu", **kwargs):
+    def __init__(
+        self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
+    ):
         """
         Args:
             dtype: data type to set for the inputs
-            device: str or torch.device
+            device: Device (as str or torch.device)
             kwargs: any number of keyword arguments. Any arguments which are
-                of type (float/int/tuple/tensor/array) are broadcasted and
+                of type (float/int/list/tuple/tensor/array) are broadcasted and
                 other keyword arguments are set as attributes.
         """
         super().__init__()
-        self.device = device
+        self.device = make_device(device)
         self._N = 0
         if kwargs is not None:

@@ -108,7 +112,7 @@ class TensorProperties(nn.Module):
             for k, v in kwargs.items():
                 if v is None or isinstance(v, (str, bool)):
                     setattr(self, k, v)
-                elif isinstance(v, BROADCAST_TYPES):
+                elif isinstance(v, BROADCAST_TYPES):  # pyre-fixme[6]
                     args_to_broadcast[k] = v
                 else:
                     msg = "Arg %s with type %r is not broadcastable"
@@ -152,17 +156,18 @@ class TensorProperties(nn.Module):
             msg = "Expected index of type int or slice; got %r"
             raise ValueError(msg % type(index))

-    def to(self, device: str = "cpu"):
+    def to(self, device: Device = "cpu"):
         """
         In place operation to move class properties which are tensors to a
         specified device. If self has a property "device", update this as well.
         """
+        device_ = make_device(device)
         for k in dir(self):
             v = getattr(self, k)
             if k == "device":
-                setattr(self, k, device)
-            if torch.is_tensor(v) and v.device != device:
-                setattr(self, k, v.to(device))
+                setattr(self, k, device_)
+            if torch.is_tensor(v) and v.device != device_:
+                setattr(self, k, v.to(device_))
         return self

     def clone(self, other):
@@ -257,28 +262,37 @@ class TensorProperties(nn.Module):
         return self


-def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor:
+def format_tensor(
+    input, dtype: torch.dtype = torch.float32, device: Device = "cpu"
+) -> torch.Tensor:
     """
     Helper function for converting a scalar value to a tensor.

     Args:
        input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
        dtype: data type for the input
-       device: torch device on which the tensor should be placed.
+       device: Device (as str or torch.device) on which the tensor should be placed.

    Returns:
        input_vec: torch tensor with optional added batch dimension.
    """
+    device_ = make_device(device)
    if not torch.is_tensor(input):
-        input = torch.tensor(input, dtype=dtype, device=device)
+        input = torch.tensor(input, dtype=dtype, device=device_)

    if input.dim() == 0:
        input = input.view(1)
-    if input.device != device:
-        input = input.to(device=device)
+
+    if input.device == device_:
+        return input
+
+    input = input.to(device=device)
    return input


-def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"):
+def convert_to_tensors_and_broadcast(
+    *args, dtype: torch.dtype = torch.float32, device: Device = "cpu"
+):
    """
    Helper function to handle parsing an arbitrary number of inputs (*args)
    which all need to have the same batch dimension.
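This last file carries the one behavioral nuance of the patch: TensorProperties.__init__, TensorProperties.to and format_tensor now pass the incoming value through make_device, so the stored .device attribute is always a torch.device even when the caller handed in a plain string. A small sketch of the effect, using PointLights purely as an example of a TensorProperties subclass:

import torch
from pytorch3d.renderer import PointLights

lights = PointLights(device="cpu")       # str in ...
# ... torch.device out, because __init__ now calls make_device(device).
assert isinstance(lights.device, torch.device)
assert lights.device == torch.device("cpu")

lights = lights.to(torch.device("cpu"))  # .to() accepts either form as well
assert lights.device == torch.device("cpu")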