Fix type annotations for device type

Summary: Fix type annotations for device type

Reviewed By: nikhilaravi

Differential Revision: D28971179

fbshipit-source-id: 410b673c76dfd65ac51b2d144f17ed86a04a3058
This commit is contained in:
Patrick Labatut 2021-06-09 15:48:56 -07:00 committed by Facebook GitHub Bot
parent 1f9661e150
commit 626bf3fe23
12 changed files with 110 additions and 58 deletions

View File

@ -9,6 +9,7 @@ from typing import Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from pytorch3d.common.types import Device
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.renderer import HardPhongShader
from tabulate import tabulate
@ -371,7 +372,7 @@ class R2N2(ShapeNetBase):
idxs: Optional[List[int]] = None,
view_idxs: Optional[List[int]] = None,
shader_type=HardPhongShader,
device="cpu",
device: Device = "cpu",
**kwargs
) -> torch.Tensor:
"""
@ -394,7 +395,7 @@ class R2N2(ShapeNetBase):
idxs: List[int] of indices of models to be rendered in the dataset.
shader_type: Shader to use for rendering. Examples include HardPhongShader
(default), SoftPhongShader etc or any other type of valid Shader class.
device: torch.device on which the tensors should be located.
device: Device (as str or torch.device) on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports and any of the
args that BlenderCamera supports.

View File

@ -4,6 +4,7 @@ from typing import Dict, List
import numpy as np
import torch
from pytorch3d.common.types import Device
from pytorch3d.datasets.utils import collate_batched_meshes
from pytorch3d.ops import cubify
from pytorch3d.renderer import (
@ -431,13 +432,13 @@ class BlenderCamera(CamerasBase):
(which uses Blender for rendering the views for each model).
"""
def __init__(self, R=r, T=t, K=k, device="cpu"):
def __init__(self, R=r, T=t, K=k, device: Device = "cpu"):
"""
Args:
R: Rotation matrix of shape (N, 3, 3).
T: Translation matrix of shape (N, 3).
K: Intrinsic matrix of shape (N, 4, 4).
device: torch.device or str.
device: Device (as str or torch.device).
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
@ -450,7 +451,7 @@ class BlenderCamera(CamerasBase):
def render_cubified_voxels(
voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs
voxels: torch.Tensor, shader_type=HardPhongShader, device: Device = "cpu", **kwargs
):
"""
Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh.
@ -461,7 +462,7 @@ def render_cubified_voxels(
shader_type: shader_type: shader_type: Shader to use for rendering. Examples
include HardPhongShader (default), SoftPhongShader etc or any other type
of valid Shader class.
device: torch.device on which the tensors should be located.
device: Device (as str or torch.device) on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports.
Returns:
Batch of rendered images of shape (N, H, W, 3).

View File

@ -4,6 +4,7 @@ import warnings
from typing import Dict, List, Optional, Tuple
import torch
from pytorch3d.common.types import Device
from pytorch3d.io import load_obj
from pytorch3d.renderer import (
FoVPerspectiveCameras,
@ -105,7 +106,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
sample_nums: Optional[List[int]] = None,
idxs: Optional[List[int]] = None,
shader_type=HardPhongShader,
device="cpu",
device: Device = "cpu",
**kwargs
) -> torch.Tensor:
"""
@ -129,7 +130,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
shader_type: Select shading. Valid options include HardPhongShader (default),
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
SoftSilhouetteShader.
device: torch.device on which the tensors should be located.
device: Device (as str or torch.device) on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports.
Returns:

View File

@ -9,6 +9,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from iopath.common.file_io import PathManager
from pytorch3d.common.types import Device
from pytorch3d.io.utils import _open_file, _read_image
@ -393,7 +394,7 @@ TextureImages = Dict[str, torch.Tensor]
def _parse_mtl(
f, path_manager: PathManager, device="cpu"
f, path_manager: PathManager, device: Device = "cpu"
) -> Tuple[MaterialProperties, TextureFiles]:
material_properties = {}
texture_files = {}
@ -474,7 +475,7 @@ def load_mtl(
*,
material_names: List[str],
data_dir: str,
device="cpu",
device: Device = "cpu",
path_manager: PathManager,
) -> Tuple[MaterialProperties, TextureImages]:
"""
@ -485,6 +486,7 @@ def load_mtl(
f: a file-like object of the material information.
material_names: a list of the material names found in the .obj file.
data_dir: the directory where the material texture files are located.
device: Device (as str or torch.tensor) on which to return the new tensors.
path_manager: PathManager for interpreting both f and material_names.
Returns:

View File

@ -11,6 +11,7 @@ from typing import List, Optional, Union
import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.common.types import Device
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesUV
@ -71,7 +72,7 @@ def load_obj(
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
device="cpu",
device: Device = "cpu",
path_manager: Optional[PathManager] = None,
):
"""
@ -143,7 +144,7 @@ def load_obj(
is ignored and a repeating pattern is formed.
If `texture_mode="clamp"` the values are clamped to the range [0, 1].
If None, then there is no transformation of the texture values.
device: string or torch.device on which to return the new tensors.
device: Device (as str or torch.device) on which to return the new tensors.
path_manager: optionally a PathManager object to interpret paths.
Returns:
@ -226,7 +227,7 @@ def load_obj(
def load_objs_as_meshes(
files: list,
device=None,
device: Optional[Device] = None,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
@ -293,7 +294,7 @@ class MeshObjFormat(MeshFormatInterpreter):
self,
path: Union[str, Path],
include_textures: bool,
device,
device: Device,
path_manager: PathManager,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
@ -497,7 +498,7 @@ def _load_materials(
*,
data_dir: str,
load_textures: bool,
device,
device: Device,
path_manager: PathManager,
):
"""
@ -508,7 +509,7 @@ def _load_materials(
f: a file-like object of the material information.
data_dir: the directory where the material texture files are located.
load_textures: whether textures should be loaded.
device: string or torch.device on which to return the new tensors.
device: Device (as str or torch.device) on which to return the new tensors.
path_manager: PathManager object to interpret paths.
Returns:
@ -546,7 +547,7 @@ def _load_obj(
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
path_manager: PathManager,
device="cpu",
device: Device = "cpu",
):
"""
Load a mesh from a file-like object. See load_obj function more details.

View File

@ -8,6 +8,7 @@ from pathlib import Path
from typing import Deque, Optional, Union
from iopath.common.file_io import PathManager
from pytorch3d.common.types import Device
from pytorch3d.structures import Meshes, Pointclouds
from .obj_io import MeshObjFormat
@ -108,7 +109,7 @@ class IO:
self,
path: Union[str, Path],
include_textures: bool = True,
device="cpu",
device: Device = "cpu",
**kwargs,
) -> Meshes:
"""
@ -168,14 +169,14 @@ class IO:
raise ValueError(f"No mesh interpreter found to write to {path}.")
def load_pointcloud(
self, path: Union[str, Path], device="cpu", **kwargs
self, path: Union[str, Path], device: Device = "cpu", **kwargs
) -> Pointclouds:
"""
Attempt to load a point cloud from the given file, using a registered format.
Args:
path: file to read
device: torch.device on which to load the data.
device: Device (as str or torch.device) on which to load the data.
Returns:
new Pointclouds object containing one mesh.

View File

@ -10,6 +10,8 @@ import torch
from iopath.common.file_io import PathManager
from PIL import Image
from ..common.types import Device
@contextlib.contextmanager
def nullcontext(x):
@ -31,7 +33,7 @@ def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
def _make_tensor(
data, cols: int, dtype: torch.dtype, device: str = "cpu"
data, cols: int, dtype: torch.dtype, device: Device = "cpu"
) -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,

View File

@ -7,6 +7,7 @@ from typing import Optional, Sequence, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.common.types import Device
from pytorch3d.transforms import Rotate, Transform3d, Translate
from .utils import TensorProperties, convert_to_tensors_and_broadcast
@ -290,7 +291,7 @@ def OpenGLPerspectiveCameras(
degrees: bool = True,
R=_R,
T=_T,
device="cpu",
device: Device = "cpu",
):
"""
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
@ -358,7 +359,7 @@ class FoVPerspectiveCameras(CamerasBase):
R=_R,
T=_T,
K=None,
device="cpu",
device: Device = "cpu",
):
"""
@ -373,7 +374,7 @@ class FoVPerspectiveCameras(CamerasBase):
T: Translation matrix of shape (N, 3)
K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, don't need znear, zfar, fov, aspect_ratio, degrees
device: torch.device or string
device: Device (as str or torch.device)
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.

View File

@ -4,6 +4,7 @@
import torch
import torch.nn.functional as F
from ..common.types import Device
from .utils import TensorProperties, convert_to_tensors_and_broadcast
@ -158,7 +159,7 @@ class DirectionalLights(TensorProperties):
diffuse_color=((0.3, 0.3, 0.3),),
specular_color=((0.2, 0.2, 0.2),),
direction=((0, 1, 0),),
device="cpu",
device: Device = "cpu",
):
"""
Args:
@ -166,7 +167,7 @@ class DirectionalLights(TensorProperties):
diffuse_color: RGB color of the diffuse component.
specular_color: RGB color of the specular component.
direction: (x, y, z) direction vector of the light.
device: torch.device on which the tensors should be located
device: Device (as str or torch.device) on which the tensors should be located
The inputs can each be
- 3 element tuple/list or list of lists
@ -219,7 +220,7 @@ class PointLights(TensorProperties):
diffuse_color=((0.3, 0.3, 0.3),),
specular_color=((0.2, 0.2, 0.2),),
location=((0, 1, 0),),
device="cpu",
device: Device = "cpu",
):
"""
Args:
@ -227,7 +228,7 @@ class PointLights(TensorProperties):
diffuse_color: RGB color of the diffuse component
specular_color: RGB color of the specular component
location: xyz position of the light.
device: torch.device on which the tensors should be located
device: Device (as str or torch.device) on which the tensors should be located
The inputs can each be
- 3 element tuple/list or list of lists
@ -275,14 +276,14 @@ class AmbientLights(TensorProperties):
not used in rendering.
"""
def __init__(self, *, ambient_color=None, device="cpu"):
def __init__(self, *, ambient_color=None, device: Device = "cpu"):
"""
If ambient_color is provided, it should be a sequence of
triples of floats.
Args:
ambient_color: RGB color
device: torch.device on which the tensors should be located
device: Device (as str or torch.device) on which the tensors should be located
The ambient_color if provided, should be
- 3 element tuple/list or list of lists

View File

@ -3,6 +3,7 @@
import torch
from ..common.types import Device
from .utils import TensorProperties
@ -18,7 +19,7 @@ class Materials(TensorProperties):
diffuse_color=((1, 1, 1),),
specular_color=((1, 1, 1),),
shininess=64,
device="cpu",
device: Device = "cpu",
):
"""
Args:
@ -29,7 +30,7 @@ class Materials(TensorProperties):
the focus of the specular highlight with a high value
resulting in a concentrated highlight. Shininess values
can range from 0-1000.
device: torch.device or string
device: Device (as str or torch.device) on which the tensors should be located
ambient_color, diffuse_color and specular_color can be of shape
(1, 3) or (N, 3). shininess can be of shape (1) or (N).

View File

@ -5,6 +5,7 @@ import warnings
import torch
import torch.nn as nn
from ...common.types import Device
from ..blending import (
BlendParams,
hard_rgb_blend,
@ -40,7 +41,12 @@ class HardPhongShader(nn.Module):
"""
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -50,7 +56,7 @@ class HardPhongShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
@ -95,7 +101,12 @@ class SoftPhongShader(nn.Module):
"""
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -105,7 +116,7 @@ class SoftPhongShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
@ -155,7 +166,12 @@ class HardGouraudShader(nn.Module):
"""
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -165,7 +181,7 @@ class HardGouraudShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
@ -214,7 +230,12 @@ class SoftGouraudShader(nn.Module):
"""
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -224,7 +245,7 @@ class SoftGouraudShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
@ -255,7 +276,7 @@ class SoftGouraudShader(nn.Module):
def TexturedSoftPhongShader(
device="cpu", cameras=None, lights=None, materials=None, blend_params=None
device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
):
"""
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
@ -290,7 +311,12 @@ class HardFlatShader(nn.Module):
"""
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -300,7 +326,7 @@ class HardFlatShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)

View File

@ -10,6 +10,8 @@ import numpy as np
import torch
import torch.nn as nn
from ..common.types import Device, make_device
class TensorAccessor(nn.Module):
"""
@ -88,17 +90,19 @@ class TensorProperties(nn.Module):
A mix-in class for storing tensors as properties with helper methods.
"""
def __init__(self, dtype=torch.float32, device="cpu", **kwargs):
def __init__(
self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
):
"""
Args:
dtype: data type to set for the inputs
device: str or torch.device
device: Device (as str or torch.device)
kwargs: any number of keyword arguments. Any arguments which are
of type (float/int/tuple/tensor/array) are broadcasted and
of type (float/int/list/tuple/tensor/array) are broadcasted and
other keyword arguments are set as attributes.
"""
super().__init__()
self.device = device
self.device = make_device(device)
self._N = 0
if kwargs is not None:
@ -108,7 +112,7 @@ class TensorProperties(nn.Module):
for k, v in kwargs.items():
if v is None or isinstance(v, (str, bool)):
setattr(self, k, v)
elif isinstance(v, BROADCAST_TYPES):
elif isinstance(v, BROADCAST_TYPES): # pyre-fixme[6]
args_to_broadcast[k] = v
else:
msg = "Arg %s with type %r is not broadcastable"
@ -152,17 +156,18 @@ class TensorProperties(nn.Module):
msg = "Expected index of type int or slice; got %r"
raise ValueError(msg % type(index))
def to(self, device: str = "cpu"):
def to(self, device: Device = "cpu"):
"""
In place operation to move class properties which are tensors to a
specified device. If self has a property "device", update this as well.
"""
device_ = make_device(device)
for k in dir(self):
v = getattr(self, k)
if k == "device":
setattr(self, k, device)
if torch.is_tensor(v) and v.device != device:
setattr(self, k, v.to(device))
setattr(self, k, device_)
if torch.is_tensor(v) and v.device != device_:
setattr(self, k, v.to(device_))
return self
def clone(self, other):
@ -257,28 +262,37 @@ class TensorProperties(nn.Module):
return self
def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor:
def format_tensor(
input, dtype: torch.dtype = torch.float32, device: Device = "cpu"
) -> torch.Tensor:
"""
Helper function for converting a scalar value to a tensor.
Args:
input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
dtype: data type for the input
device: torch device on which the tensor should be placed.
device: Device (as str or torch.device) on which the tensor should be placed.
Returns:
input_vec: torch tensor with optional added batch dimension.
"""
device_ = make_device(device)
if not torch.is_tensor(input):
input = torch.tensor(input, dtype=dtype, device=device)
input = torch.tensor(input, dtype=dtype, device=device_)
if input.dim() == 0:
input = input.view(1)
if input.device != device:
input = input.to(device=device)
if input.device == device_:
return input
input = input.to(device=device)
return input
def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"):
def convert_to_tensors_and_broadcast(
*args, dtype: torch.dtype = torch.float32, device: Device = "cpu"
):
"""
Helper function to handle parsing an arbitrary number of inputs (*args)
which all need to have the same batch dimension.