Fix type annotations for device type

Summary: Fix type annotations for device type

Reviewed By: nikhilaravi

Differential Revision: D28971179

fbshipit-source-id: 410b673c76dfd65ac51b2d144f17ed86a04a3058
This commit is contained in:
Patrick Labatut 2021-06-09 15:48:56 -07:00 committed by Facebook GitHub Bot
parent 1f9661e150
commit 626bf3fe23
12 changed files with 110 additions and 58 deletions

View File

@ -9,6 +9,7 @@ from typing import Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from pytorch3d.common.types import Device
from pytorch3d.datasets.shapenet_base import ShapeNetBase from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.renderer import HardPhongShader from pytorch3d.renderer import HardPhongShader
from tabulate import tabulate from tabulate import tabulate
@ -371,7 +372,7 @@ class R2N2(ShapeNetBase):
idxs: Optional[List[int]] = None, idxs: Optional[List[int]] = None,
view_idxs: Optional[List[int]] = None, view_idxs: Optional[List[int]] = None,
shader_type=HardPhongShader, shader_type=HardPhongShader,
device="cpu", device: Device = "cpu",
**kwargs **kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -394,7 +395,7 @@ class R2N2(ShapeNetBase):
idxs: List[int] of indices of models to be rendered in the dataset. idxs: List[int] of indices of models to be rendered in the dataset.
shader_type: Shader to use for rendering. Examples include HardPhongShader shader_type: Shader to use for rendering. Examples include HardPhongShader
(default), SoftPhongShader etc or any other type of valid Shader class. (default), SoftPhongShader etc or any other type of valid Shader class.
device: torch.device on which the tensors should be located. device: Device (as str or torch.device) on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports and any of the **kwargs: Accepts any of the kwargs that the renderer supports and any of the
args that BlenderCamera supports. args that BlenderCamera supports.

View File

@ -4,6 +4,7 @@ from typing import Dict, List
import numpy as np import numpy as np
import torch import torch
from pytorch3d.common.types import Device
from pytorch3d.datasets.utils import collate_batched_meshes from pytorch3d.datasets.utils import collate_batched_meshes
from pytorch3d.ops import cubify from pytorch3d.ops import cubify
from pytorch3d.renderer import ( from pytorch3d.renderer import (
@ -431,13 +432,13 @@ class BlenderCamera(CamerasBase):
(which uses Blender for rendering the views for each model). (which uses Blender for rendering the views for each model).
""" """
def __init__(self, R=r, T=t, K=k, device="cpu"): def __init__(self, R=r, T=t, K=k, device: Device = "cpu"):
""" """
Args: Args:
R: Rotation matrix of shape (N, 3, 3). R: Rotation matrix of shape (N, 3, 3).
T: Translation matrix of shape (N, 3). T: Translation matrix of shape (N, 3).
K: Intrinsic matrix of shape (N, 4, 4). K: Intrinsic matrix of shape (N, 4, 4).
device: torch.device or str. device: Device (as str or torch.device).
""" """
# The initializer formats all inputs to torch tensors and broadcasts # The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary. # all the inputs to have the same batch dimension where necessary.
@ -450,7 +451,7 @@ class BlenderCamera(CamerasBase):
def render_cubified_voxels( def render_cubified_voxels(
voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs voxels: torch.Tensor, shader_type=HardPhongShader, device: Device = "cpu", **kwargs
): ):
""" """
Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh. Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh.
@ -461,7 +462,7 @@ def render_cubified_voxels(
shader_type: shader_type: shader_type: Shader to use for rendering. Examples shader_type: shader_type: shader_type: Shader to use for rendering. Examples
include HardPhongShader (default), SoftPhongShader etc or any other type include HardPhongShader (default), SoftPhongShader etc or any other type
of valid Shader class. of valid Shader class.
device: torch.device on which the tensors should be located. device: Device (as str or torch.device) on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports. **kwargs: Accepts any of the kwargs that the renderer supports.
Returns: Returns:
Batch of rendered images of shape (N, H, W, 3). Batch of rendered images of shape (N, H, W, 3).

View File

@ -4,6 +4,7 @@ import warnings
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
from pytorch3d.common.types import Device
from pytorch3d.io import load_obj from pytorch3d.io import load_obj
from pytorch3d.renderer import ( from pytorch3d.renderer import (
FoVPerspectiveCameras, FoVPerspectiveCameras,
@ -105,7 +106,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
sample_nums: Optional[List[int]] = None, sample_nums: Optional[List[int]] = None,
idxs: Optional[List[int]] = None, idxs: Optional[List[int]] = None,
shader_type=HardPhongShader, shader_type=HardPhongShader,
device="cpu", device: Device = "cpu",
**kwargs **kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -129,7 +130,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
shader_type: Select shading. Valid options include HardPhongShader (default), shader_type: Select shading. Valid options include HardPhongShader (default),
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader, SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
SoftSilhouetteShader. SoftSilhouetteShader.
device: torch.device on which the tensors should be located. device: Device (as str or torch.device) on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports. **kwargs: Accepts any of the kwargs that the renderer supports.
Returns: Returns:

View File

@ -9,6 +9,7 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.common.types import Device
from pytorch3d.io.utils import _open_file, _read_image from pytorch3d.io.utils import _open_file, _read_image
@ -393,7 +394,7 @@ TextureImages = Dict[str, torch.Tensor]
def _parse_mtl( def _parse_mtl(
f, path_manager: PathManager, device="cpu" f, path_manager: PathManager, device: Device = "cpu"
) -> Tuple[MaterialProperties, TextureFiles]: ) -> Tuple[MaterialProperties, TextureFiles]:
material_properties = {} material_properties = {}
texture_files = {} texture_files = {}
@ -474,7 +475,7 @@ def load_mtl(
*, *,
material_names: List[str], material_names: List[str],
data_dir: str, data_dir: str,
device="cpu", device: Device = "cpu",
path_manager: PathManager, path_manager: PathManager,
) -> Tuple[MaterialProperties, TextureImages]: ) -> Tuple[MaterialProperties, TextureImages]:
""" """
@ -485,6 +486,7 @@ def load_mtl(
f: a file-like object of the material information. f: a file-like object of the material information.
material_names: a list of the material names found in the .obj file. material_names: a list of the material names found in the .obj file.
data_dir: the directory where the material texture files are located. data_dir: the directory where the material texture files are located.
device: Device (as str or torch.tensor) on which to return the new tensors.
path_manager: PathManager for interpreting both f and material_names. path_manager: PathManager for interpreting both f and material_names.
Returns: Returns:

View File

@ -11,6 +11,7 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.common.types import Device
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesUV from pytorch3d.renderer import TexturesAtlas, TexturesUV
@ -71,7 +72,7 @@ def load_obj(
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat", texture_wrap: Optional[str] = "repeat",
device="cpu", device: Device = "cpu",
path_manager: Optional[PathManager] = None, path_manager: Optional[PathManager] = None,
): ):
""" """
@ -143,7 +144,7 @@ def load_obj(
is ignored and a repeating pattern is formed. is ignored and a repeating pattern is formed.
If `texture_mode="clamp"` the values are clamped to the range [0, 1]. If `texture_mode="clamp"` the values are clamped to the range [0, 1].
If None, then there is no transformation of the texture values. If None, then there is no transformation of the texture values.
device: string or torch.device on which to return the new tensors. device: Device (as str or torch.device) on which to return the new tensors.
path_manager: optionally a PathManager object to interpret paths. path_manager: optionally a PathManager object to interpret paths.
Returns: Returns:
@ -226,7 +227,7 @@ def load_obj(
def load_objs_as_meshes( def load_objs_as_meshes(
files: list, files: list,
device=None, device: Optional[Device] = None,
load_textures: bool = True, load_textures: bool = True,
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
@ -293,7 +294,7 @@ class MeshObjFormat(MeshFormatInterpreter):
self, self,
path: Union[str, Path], path: Union[str, Path],
include_textures: bool, include_textures: bool,
device, device: Device,
path_manager: PathManager, path_manager: PathManager,
create_texture_atlas: bool = False, create_texture_atlas: bool = False,
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
@ -497,7 +498,7 @@ def _load_materials(
*, *,
data_dir: str, data_dir: str,
load_textures: bool, load_textures: bool,
device, device: Device,
path_manager: PathManager, path_manager: PathManager,
): ):
""" """
@ -508,7 +509,7 @@ def _load_materials(
f: a file-like object of the material information. f: a file-like object of the material information.
data_dir: the directory where the material texture files are located. data_dir: the directory where the material texture files are located.
load_textures: whether textures should be loaded. load_textures: whether textures should be loaded.
device: string or torch.device on which to return the new tensors. device: Device (as str or torch.device) on which to return the new tensors.
path_manager: PathManager object to interpret paths. path_manager: PathManager object to interpret paths.
Returns: Returns:
@ -546,7 +547,7 @@ def _load_obj(
texture_atlas_size: int = 4, texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat", texture_wrap: Optional[str] = "repeat",
path_manager: PathManager, path_manager: PathManager,
device="cpu", device: Device = "cpu",
): ):
""" """
Load a mesh from a file-like object. See load_obj function more details. Load a mesh from a file-like object. See load_obj function more details.

View File

@ -8,6 +8,7 @@ from pathlib import Path
from typing import Deque, Optional, Union from typing import Deque, Optional, Union
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.common.types import Device
from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.structures import Meshes, Pointclouds
from .obj_io import MeshObjFormat from .obj_io import MeshObjFormat
@ -108,7 +109,7 @@ class IO:
self, self,
path: Union[str, Path], path: Union[str, Path],
include_textures: bool = True, include_textures: bool = True,
device="cpu", device: Device = "cpu",
**kwargs, **kwargs,
) -> Meshes: ) -> Meshes:
""" """
@ -168,14 +169,14 @@ class IO:
raise ValueError(f"No mesh interpreter found to write to {path}.") raise ValueError(f"No mesh interpreter found to write to {path}.")
def load_pointcloud( def load_pointcloud(
self, path: Union[str, Path], device="cpu", **kwargs self, path: Union[str, Path], device: Device = "cpu", **kwargs
) -> Pointclouds: ) -> Pointclouds:
""" """
Attempt to load a point cloud from the given file, using a registered format. Attempt to load a point cloud from the given file, using a registered format.
Args: Args:
path: file to read path: file to read
device: torch.device on which to load the data. device: Device (as str or torch.device) on which to load the data.
Returns: Returns:
new Pointclouds object containing one mesh. new Pointclouds object containing one mesh.

View File

@ -10,6 +10,8 @@ import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from PIL import Image from PIL import Image
from ..common.types import Device
@contextlib.contextmanager @contextlib.contextmanager
def nullcontext(x): def nullcontext(x):
@ -31,7 +33,7 @@ def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
def _make_tensor( def _make_tensor(
data, cols: int, dtype: torch.dtype, device: str = "cpu" data, cols: int, dtype: torch.dtype, device: Device = "cpu"
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Return a 2D tensor with the specified cols and dtype filled with data, Return a 2D tensor with the specified cols and dtype filled with data,

View File

@ -7,6 +7,7 @@ from typing import Optional, Sequence, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.common.types import Device
from pytorch3d.transforms import Rotate, Transform3d, Translate from pytorch3d.transforms import Rotate, Transform3d, Translate
from .utils import TensorProperties, convert_to_tensors_and_broadcast from .utils import TensorProperties, convert_to_tensors_and_broadcast
@ -290,7 +291,7 @@ def OpenGLPerspectiveCameras(
degrees: bool = True, degrees: bool = True,
R=_R, R=_R,
T=_T, T=_T,
device="cpu", device: Device = "cpu",
): ):
""" """
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead. OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
@ -358,7 +359,7 @@ class FoVPerspectiveCameras(CamerasBase):
R=_R, R=_R,
T=_T, T=_T,
K=None, K=None,
device="cpu", device: Device = "cpu",
): ):
""" """
@ -373,7 +374,7 @@ class FoVPerspectiveCameras(CamerasBase):
T: Translation matrix of shape (N, 3) T: Translation matrix of shape (N, 3)
K: (optional) A calibration matrix of shape (N, 4, 4) K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, don't need znear, zfar, fov, aspect_ratio, degrees If provided, don't need znear, zfar, fov, aspect_ratio, degrees
device: torch.device or string device: Device (as str or torch.device)
""" """
# The initializer formats all inputs to torch tensors and broadcasts # The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary. # all the inputs to have the same batch dimension where necessary.

View File

@ -4,6 +4,7 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ..common.types import Device
from .utils import TensorProperties, convert_to_tensors_and_broadcast from .utils import TensorProperties, convert_to_tensors_and_broadcast
@ -158,7 +159,7 @@ class DirectionalLights(TensorProperties):
diffuse_color=((0.3, 0.3, 0.3),), diffuse_color=((0.3, 0.3, 0.3),),
specular_color=((0.2, 0.2, 0.2),), specular_color=((0.2, 0.2, 0.2),),
direction=((0, 1, 0),), direction=((0, 1, 0),),
device="cpu", device: Device = "cpu",
): ):
""" """
Args: Args:
@ -166,7 +167,7 @@ class DirectionalLights(TensorProperties):
diffuse_color: RGB color of the diffuse component. diffuse_color: RGB color of the diffuse component.
specular_color: RGB color of the specular component. specular_color: RGB color of the specular component.
direction: (x, y, z) direction vector of the light. direction: (x, y, z) direction vector of the light.
device: torch.device on which the tensors should be located device: Device (as str or torch.device) on which the tensors should be located
The inputs can each be The inputs can each be
- 3 element tuple/list or list of lists - 3 element tuple/list or list of lists
@ -219,7 +220,7 @@ class PointLights(TensorProperties):
diffuse_color=((0.3, 0.3, 0.3),), diffuse_color=((0.3, 0.3, 0.3),),
specular_color=((0.2, 0.2, 0.2),), specular_color=((0.2, 0.2, 0.2),),
location=((0, 1, 0),), location=((0, 1, 0),),
device="cpu", device: Device = "cpu",
): ):
""" """
Args: Args:
@ -227,7 +228,7 @@ class PointLights(TensorProperties):
diffuse_color: RGB color of the diffuse component diffuse_color: RGB color of the diffuse component
specular_color: RGB color of the specular component specular_color: RGB color of the specular component
location: xyz position of the light. location: xyz position of the light.
device: torch.device on which the tensors should be located device: Device (as str or torch.device) on which the tensors should be located
The inputs can each be The inputs can each be
- 3 element tuple/list or list of lists - 3 element tuple/list or list of lists
@ -275,14 +276,14 @@ class AmbientLights(TensorProperties):
not used in rendering. not used in rendering.
""" """
def __init__(self, *, ambient_color=None, device="cpu"): def __init__(self, *, ambient_color=None, device: Device = "cpu"):
""" """
If ambient_color is provided, it should be a sequence of If ambient_color is provided, it should be a sequence of
triples of floats. triples of floats.
Args: Args:
ambient_color: RGB color ambient_color: RGB color
device: torch.device on which the tensors should be located device: Device (as str or torch.device) on which the tensors should be located
The ambient_color if provided, should be The ambient_color if provided, should be
- 3 element tuple/list or list of lists - 3 element tuple/list or list of lists

View File

@ -3,6 +3,7 @@
import torch import torch
from ..common.types import Device
from .utils import TensorProperties from .utils import TensorProperties
@ -18,7 +19,7 @@ class Materials(TensorProperties):
diffuse_color=((1, 1, 1),), diffuse_color=((1, 1, 1),),
specular_color=((1, 1, 1),), specular_color=((1, 1, 1),),
shininess=64, shininess=64,
device="cpu", device: Device = "cpu",
): ):
""" """
Args: Args:
@ -29,7 +30,7 @@ class Materials(TensorProperties):
the focus of the specular highlight with a high value the focus of the specular highlight with a high value
resulting in a concentrated highlight. Shininess values resulting in a concentrated highlight. Shininess values
can range from 0-1000. can range from 0-1000.
device: torch.device or string device: Device (as str or torch.device) on which the tensors should be located
ambient_color, diffuse_color and specular_color can be of shape ambient_color, diffuse_color and specular_color can be of shape
(1, 3) or (N, 3). shininess can be of shape (1) or (N). (1, 3) or (N, 3). shininess can be of shape (1) or (N).

View File

@ -5,6 +5,7 @@ import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...common.types import Device
from ..blending import ( from ..blending import (
BlendParams, BlendParams,
hard_rgb_blend, hard_rgb_blend,
@ -40,7 +41,12 @@ class HardPhongShader(nn.Module):
""" """
def __init__( def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
@ -50,7 +56,7 @@ class HardPhongShader(nn.Module):
self.cameras = cameras self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
@ -95,7 +101,12 @@ class SoftPhongShader(nn.Module):
""" """
def __init__( def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
@ -105,7 +116,7 @@ class SoftPhongShader(nn.Module):
self.cameras = cameras self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
@ -155,7 +166,12 @@ class HardGouraudShader(nn.Module):
""" """
def __init__( def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
@ -165,7 +181,7 @@ class HardGouraudShader(nn.Module):
self.cameras = cameras self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
@ -214,7 +230,12 @@ class SoftGouraudShader(nn.Module):
""" """
def __init__( def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
@ -224,7 +245,7 @@ class SoftGouraudShader(nn.Module):
self.cameras = cameras self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
@ -255,7 +276,7 @@ class SoftGouraudShader(nn.Module):
def TexturedSoftPhongShader( def TexturedSoftPhongShader(
device="cpu", cameras=None, lights=None, materials=None, blend_params=None device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
): ):
""" """
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead. TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
@ -290,7 +311,12 @@ class HardFlatShader(nn.Module):
""" """
def __init__( def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
): ):
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
@ -300,7 +326,7 @@ class HardFlatShader(nn.Module):
self.cameras = cameras self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)

View File

@ -10,6 +10,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..common.types import Device, make_device
class TensorAccessor(nn.Module): class TensorAccessor(nn.Module):
""" """
@ -88,17 +90,19 @@ class TensorProperties(nn.Module):
A mix-in class for storing tensors as properties with helper methods. A mix-in class for storing tensors as properties with helper methods.
""" """
def __init__(self, dtype=torch.float32, device="cpu", **kwargs): def __init__(
self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
):
""" """
Args: Args:
dtype: data type to set for the inputs dtype: data type to set for the inputs
device: str or torch.device device: Device (as str or torch.device)
kwargs: any number of keyword arguments. Any arguments which are kwargs: any number of keyword arguments. Any arguments which are
of type (float/int/tuple/tensor/array) are broadcasted and of type (float/int/list/tuple/tensor/array) are broadcasted and
other keyword arguments are set as attributes. other keyword arguments are set as attributes.
""" """
super().__init__() super().__init__()
self.device = device self.device = make_device(device)
self._N = 0 self._N = 0
if kwargs is not None: if kwargs is not None:
@ -108,7 +112,7 @@ class TensorProperties(nn.Module):
for k, v in kwargs.items(): for k, v in kwargs.items():
if v is None or isinstance(v, (str, bool)): if v is None or isinstance(v, (str, bool)):
setattr(self, k, v) setattr(self, k, v)
elif isinstance(v, BROADCAST_TYPES): elif isinstance(v, BROADCAST_TYPES): # pyre-fixme[6]
args_to_broadcast[k] = v args_to_broadcast[k] = v
else: else:
msg = "Arg %s with type %r is not broadcastable" msg = "Arg %s with type %r is not broadcastable"
@ -152,17 +156,18 @@ class TensorProperties(nn.Module):
msg = "Expected index of type int or slice; got %r" msg = "Expected index of type int or slice; got %r"
raise ValueError(msg % type(index)) raise ValueError(msg % type(index))
def to(self, device: str = "cpu"): def to(self, device: Device = "cpu"):
""" """
In place operation to move class properties which are tensors to a In place operation to move class properties which are tensors to a
specified device. If self has a property "device", update this as well. specified device. If self has a property "device", update this as well.
""" """
device_ = make_device(device)
for k in dir(self): for k in dir(self):
v = getattr(self, k) v = getattr(self, k)
if k == "device": if k == "device":
setattr(self, k, device) setattr(self, k, device_)
if torch.is_tensor(v) and v.device != device: if torch.is_tensor(v) and v.device != device_:
setattr(self, k, v.to(device)) setattr(self, k, v.to(device_))
return self return self
def clone(self, other): def clone(self, other):
@ -257,28 +262,37 @@ class TensorProperties(nn.Module):
return self return self
def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor: def format_tensor(
input, dtype: torch.dtype = torch.float32, device: Device = "cpu"
) -> torch.Tensor:
""" """
Helper function for converting a scalar value to a tensor. Helper function for converting a scalar value to a tensor.
Args: Args:
input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
dtype: data type for the input dtype: data type for the input
device: torch device on which the tensor should be placed. device: Device (as str or torch.device) on which the tensor should be placed.
Returns: Returns:
input_vec: torch tensor with optional added batch dimension. input_vec: torch tensor with optional added batch dimension.
""" """
device_ = make_device(device)
if not torch.is_tensor(input): if not torch.is_tensor(input):
input = torch.tensor(input, dtype=dtype, device=device) input = torch.tensor(input, dtype=dtype, device=device_)
if input.dim() == 0: if input.dim() == 0:
input = input.view(1) input = input.view(1)
if input.device != device:
input = input.to(device=device) if input.device == device_:
return input
input = input.to(device=device)
return input return input
def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"): def convert_to_tensors_and_broadcast(
*args, dtype: torch.dtype = torch.float32, device: Device = "cpu"
):
""" """
Helper function to handle parsing an arbitrary number of inputs (*args) Helper function to handle parsing an arbitrary number of inputs (*args)
which all need to have the same batch dimension. which all need to have the same batch dimension.