diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py index dff335c5..6a3f6232 100644 --- a/pytorch3d/datasets/r2n2/r2n2.py +++ b/pytorch3d/datasets/r2n2/r2n2.py @@ -9,6 +9,7 @@ from typing import Dict, List, Optional import numpy as np import torch from PIL import Image +from pytorch3d.common.types import Device from pytorch3d.datasets.shapenet_base import ShapeNetBase from pytorch3d.renderer import HardPhongShader from tabulate import tabulate @@ -371,7 +372,7 @@ class R2N2(ShapeNetBase): idxs: Optional[List[int]] = None, view_idxs: Optional[List[int]] = None, shader_type=HardPhongShader, - device="cpu", + device: Device = "cpu", **kwargs ) -> torch.Tensor: """ @@ -394,7 +395,7 @@ class R2N2(ShapeNetBase): idxs: List[int] of indices of models to be rendered in the dataset. shader_type: Shader to use for rendering. Examples include HardPhongShader (default), SoftPhongShader etc or any other type of valid Shader class. - device: torch.device on which the tensors should be located. + device: Device (as str or torch.device) on which the tensors should be located. **kwargs: Accepts any of the kwargs that the renderer supports and any of the args that BlenderCamera supports. diff --git a/pytorch3d/datasets/r2n2/utils.py b/pytorch3d/datasets/r2n2/utils.py index f2c96802..b77e48d5 100644 --- a/pytorch3d/datasets/r2n2/utils.py +++ b/pytorch3d/datasets/r2n2/utils.py @@ -4,6 +4,7 @@ from typing import Dict, List import numpy as np import torch +from pytorch3d.common.types import Device from pytorch3d.datasets.utils import collate_batched_meshes from pytorch3d.ops import cubify from pytorch3d.renderer import ( @@ -431,13 +432,13 @@ class BlenderCamera(CamerasBase): (which uses Blender for rendering the views for each model). """ - def __init__(self, R=r, T=t, K=k, device="cpu"): + def __init__(self, R=r, T=t, K=k, device: Device = "cpu"): """ Args: R: Rotation matrix of shape (N, 3, 3). T: Translation matrix of shape (N, 3). 
K: Intrinsic matrix of shape (N, 4, 4). - device: torch.device or str. + device: Device (as str or torch.device). """ # The initializer formats all inputs to torch tensors and broadcasts # all the inputs to have the same batch dimension where necessary. @@ -450,7 +451,7 @@ class BlenderCamera(CamerasBase): def render_cubified_voxels( - voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs + voxels: torch.Tensor, shader_type=HardPhongShader, device: Device = "cpu", **kwargs ): """ Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh. @@ -461,7 +462,7 @@ def render_cubified_voxels( shader_type: shader_type: shader_type: Shader to use for rendering. Examples include HardPhongShader (default), SoftPhongShader etc or any other type of valid Shader class. - device: torch.device on which the tensors should be located. + device: Device (as str or torch.device) on which the tensors should be located. **kwargs: Accepts any of the kwargs that the renderer supports. Returns: Batch of rendered images of shape (N, H, W, 3). diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py index 6eca6f84..0d9de02b 100644 --- a/pytorch3d/datasets/shapenet_base.py +++ b/pytorch3d/datasets/shapenet_base.py @@ -4,6 +4,7 @@ import warnings from typing import Dict, List, Optional, Tuple import torch +from pytorch3d.common.types import Device from pytorch3d.io import load_obj from pytorch3d.renderer import ( FoVPerspectiveCameras, @@ -105,7 +106,7 @@ class ShapeNetBase(torch.utils.data.Dataset): sample_nums: Optional[List[int]] = None, idxs: Optional[List[int]] = None, shader_type=HardPhongShader, - device="cpu", + device: Device = "cpu", **kwargs ) -> torch.Tensor: """ @@ -129,7 +130,7 @@ class ShapeNetBase(torch.utils.data.Dataset): shader_type: Select shading. Valid options include HardPhongShader (default), SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader, SoftSilhouetteShader. 
- device: torch.device on which the tensors should be located. + device: Device (as str or torch.device) on which the tensors should be located. **kwargs: Accepts any of the kwargs that the renderer supports. Returns: diff --git a/pytorch3d/io/mtl_io.py b/pytorch3d/io/mtl_io.py index 6d3fa602..29539610 100644 --- a/pytorch3d/io/mtl_io.py +++ b/pytorch3d/io/mtl_io.py @@ -9,6 +9,7 @@ import numpy as np import torch import torch.nn.functional as F from iopath.common.file_io import PathManager +from pytorch3d.common.types import Device from pytorch3d.io.utils import _open_file, _read_image @@ -393,7 +394,7 @@ TextureImages = Dict[str, torch.Tensor] def _parse_mtl( - f, path_manager: PathManager, device="cpu" + f, path_manager: PathManager, device: Device = "cpu" ) -> Tuple[MaterialProperties, TextureFiles]: material_properties = {} texture_files = {} @@ -474,7 +475,7 @@ def load_mtl( *, material_names: List[str], data_dir: str, - device="cpu", + device: Device = "cpu", path_manager: PathManager, ) -> Tuple[MaterialProperties, TextureImages]: """ @@ -485,6 +486,7 @@ def load_mtl( f: a file-like object of the material information. material_names: a list of the material names found in the .obj file. data_dir: the directory where the material texture files are located. + device: Device (as str or torch.device) on which to return the new tensors. path_manager: PathManager for interpreting both f and material_names. 
Returns: diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index 12ac537b..9311fd15 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -11,6 +11,7 @@ from typing import List, Optional, Union import numpy as np import torch from iopath.common.file_io import PathManager +from pytorch3d.common.types import Device from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.renderer import TexturesAtlas, TexturesUV @@ -71,7 +72,7 @@ def load_obj( create_texture_atlas: bool = False, texture_atlas_size: int = 4, texture_wrap: Optional[str] = "repeat", - device="cpu", + device: Device = "cpu", path_manager: Optional[PathManager] = None, ): """ @@ -143,7 +144,7 @@ def load_obj( is ignored and a repeating pattern is formed. If `texture_mode="clamp"` the values are clamped to the range [0, 1]. If None, then there is no transformation of the texture values. - device: string or torch.device on which to return the new tensors. + device: Device (as str or torch.device) on which to return the new tensors. path_manager: optionally a PathManager object to interpret paths. Returns: @@ -226,7 +227,7 @@ def load_obj( def load_objs_as_meshes( files: list, - device=None, + device: Optional[Device] = None, load_textures: bool = True, create_texture_atlas: bool = False, texture_atlas_size: int = 4, @@ -293,7 +294,7 @@ class MeshObjFormat(MeshFormatInterpreter): self, path: Union[str, Path], include_textures: bool, - device, + device: Device, path_manager: PathManager, create_texture_atlas: bool = False, texture_atlas_size: int = 4, @@ -497,7 +498,7 @@ def _load_materials( *, data_dir: str, load_textures: bool, - device, + device: Device, path_manager: PathManager, ): """ @@ -508,7 +509,7 @@ def _load_materials( f: a file-like object of the material information. data_dir: the directory where the material texture files are located. 
load_textures: whether textures should be loaded. - device: string or torch.device on which to return the new tensors. + device: Device (as str or torch.device) on which to return the new tensors. path_manager: PathManager object to interpret paths. Returns: @@ -546,7 +547,7 @@ def _load_obj( texture_atlas_size: int = 4, texture_wrap: Optional[str] = "repeat", path_manager: PathManager, - device="cpu", + device: Device = "cpu", ): """ Load a mesh from a file-like object. See load_obj function more details. diff --git a/pytorch3d/io/pluggable.py b/pytorch3d/io/pluggable.py index 8c729300..bbdc589b 100644 --- a/pytorch3d/io/pluggable.py +++ b/pytorch3d/io/pluggable.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Deque, Optional, Union from iopath.common.file_io import PathManager +from pytorch3d.common.types import Device from pytorch3d.structures import Meshes, Pointclouds from .obj_io import MeshObjFormat @@ -108,7 +109,7 @@ class IO: self, path: Union[str, Path], include_textures: bool = True, - device="cpu", + device: Device = "cpu", **kwargs, ) -> Meshes: """ @@ -168,14 +169,14 @@ class IO: raise ValueError(f"No mesh interpreter found to write to {path}.") def load_pointcloud( - self, path: Union[str, Path], device="cpu", **kwargs + self, path: Union[str, Path], device: Device = "cpu", **kwargs ) -> Pointclouds: """ Attempt to load a point cloud from the given file, using a registered format. Args: path: file to read - device: torch.device on which to load the data. + device: Device (as str or torch.device) on which to load the data. Returns: new Pointclouds object containing one mesh. 
diff --git a/pytorch3d/io/utils.py b/pytorch3d/io/utils.py index a9191789..e1d5c874 100644 --- a/pytorch3d/io/utils.py +++ b/pytorch3d/io/utils.py @@ -10,6 +10,8 @@ import torch from iopath.common.file_io import PathManager from PIL import Image +from ..common.types import Device + @contextlib.contextmanager def nullcontext(x): @@ -31,7 +33,7 @@ def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]: def _make_tensor( - data, cols: int, dtype: torch.dtype, device: str = "cpu" + data, cols: int, dtype: torch.dtype, device: Device = "cpu" ) -> torch.Tensor: """ Return a 2D tensor with the specified cols and dtype filled with data, diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index 498ae17b..3c2f894a 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -7,6 +7,7 @@ from typing import Optional, Sequence, Tuple import numpy as np import torch import torch.nn.functional as F +from pytorch3d.common.types import Device from pytorch3d.transforms import Rotate, Transform3d, Translate from .utils import TensorProperties, convert_to_tensors_and_broadcast @@ -290,7 +291,7 @@ def OpenGLPerspectiveCameras( degrees: bool = True, R=_R, T=_T, - device="cpu", + device: Device = "cpu", ): """ OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead. @@ -358,7 +359,7 @@ class FoVPerspectiveCameras(CamerasBase): R=_R, T=_T, K=None, - device="cpu", + device: Device = "cpu", ): """ @@ -373,7 +374,7 @@ class FoVPerspectiveCameras(CamerasBase): T: Translation matrix of shape (N, 3) K: (optional) A calibration matrix of shape (N, 4, 4) If provided, don't need znear, zfar, fov, aspect_ratio, degrees - device: torch.device or string + device: Device (as str or torch.device) """ # The initializer formats all inputs to torch tensors and broadcasts # all the inputs to have the same batch dimension where necessary. 
diff --git a/pytorch3d/renderer/lighting.py b/pytorch3d/renderer/lighting.py index 05ad4858..915f7141 100644 --- a/pytorch3d/renderer/lighting.py +++ b/pytorch3d/renderer/lighting.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F +from ..common.types import Device from .utils import TensorProperties, convert_to_tensors_and_broadcast @@ -158,7 +159,7 @@ class DirectionalLights(TensorProperties): diffuse_color=((0.3, 0.3, 0.3),), specular_color=((0.2, 0.2, 0.2),), direction=((0, 1, 0),), - device="cpu", + device: Device = "cpu", ): """ Args: @@ -166,7 +167,7 @@ class DirectionalLights(TensorProperties): diffuse_color: RGB color of the diffuse component. specular_color: RGB color of the specular component. direction: (x, y, z) direction vector of the light. - device: torch.device on which the tensors should be located + device: Device (as str or torch.device) on which the tensors should be located The inputs can each be - 3 element tuple/list or list of lists @@ -219,7 +220,7 @@ class PointLights(TensorProperties): diffuse_color=((0.3, 0.3, 0.3),), specular_color=((0.2, 0.2, 0.2),), location=((0, 1, 0),), - device="cpu", + device: Device = "cpu", ): """ Args: @@ -227,7 +228,7 @@ class PointLights(TensorProperties): diffuse_color: RGB color of the diffuse component specular_color: RGB color of the specular component location: xyz position of the light. - device: torch.device on which the tensors should be located + device: Device (as str or torch.device) on which the tensors should be located The inputs can each be - 3 element tuple/list or list of lists @@ -275,14 +276,14 @@ class AmbientLights(TensorProperties): not used in rendering. """ - def __init__(self, *, ambient_color=None, device="cpu"): + def __init__(self, *, ambient_color=None, device: Device = "cpu"): """ If ambient_color is provided, it should be a sequence of triples of floats. 
Args: ambient_color: RGB color - device: torch.device on which the tensors should be located + device: Device (as str or torch.device) on which the tensors should be located The ambient_color if provided, should be - 3 element tuple/list or list of lists diff --git a/pytorch3d/renderer/materials.py b/pytorch3d/renderer/materials.py index ce37700b..ef428b93 100644 --- a/pytorch3d/renderer/materials.py +++ b/pytorch3d/renderer/materials.py @@ -3,6 +3,7 @@ import torch +from ..common.types import Device from .utils import TensorProperties @@ -18,7 +19,7 @@ class Materials(TensorProperties): diffuse_color=((1, 1, 1),), specular_color=((1, 1, 1),), shininess=64, - device="cpu", + device: Device = "cpu", ): """ Args: @@ -29,7 +30,7 @@ class Materials(TensorProperties): the focus of the specular highlight with a high value resulting in a concentrated highlight. Shininess values can range from 0-1000. - device: torch.device or string + device: Device (as str or torch.device) on which the tensors should be located ambient_color, diffuse_color and specular_color can be of shape (1, 3) or (N, 3). shininess can be of shape (1) or (N). 
diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index c213457d..abab7304 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -5,6 +5,7 @@ import warnings import torch import torch.nn as nn +from ...common.types import Device from ..blending import ( BlendParams, hard_rgb_blend, @@ -40,7 +41,12 @@ class HardPhongShader(nn.Module): """ def __init__( - self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + self, + device: Device = "cpu", + cameras=None, + lights=None, + materials=None, + blend_params=None, ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -50,7 +56,7 @@ class HardPhongShader(nn.Module): self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() - def to(self, device): + def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) @@ -95,7 +101,12 @@ class SoftPhongShader(nn.Module): """ def __init__( - self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + self, + device: Device = "cpu", + cameras=None, + lights=None, + materials=None, + blend_params=None, ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -105,7 +116,7 @@ class SoftPhongShader(nn.Module): self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() - def to(self, device): + def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) @@ -155,7 +166,12 @@ class HardGouraudShader(nn.Module): """ def __init__( - self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + self, + device: Device = "cpu", + 
cameras=None, + lights=None, + materials=None, + blend_params=None, ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -165,7 +181,7 @@ class HardGouraudShader(nn.Module): self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() - def to(self, device): + def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) @@ -214,7 +230,12 @@ class SoftGouraudShader(nn.Module): """ def __init__( - self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + self, + device: Device = "cpu", + cameras=None, + lights=None, + materials=None, + blend_params=None, ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -224,7 +245,7 @@ class SoftGouraudShader(nn.Module): self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() - def to(self, device): + def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) @@ -255,7 +276,7 @@ class SoftGouraudShader(nn.Module): def TexturedSoftPhongShader( - device="cpu", cameras=None, lights=None, materials=None, blend_params=None + device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None ): """ TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead. 
@@ -290,7 +311,12 @@ class HardFlatShader(nn.Module): """ def __init__( - self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + self, + device: Device = "cpu", + cameras=None, + lights=None, + materials=None, + blend_params=None, ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -300,7 +326,7 @@ class HardFlatShader(nn.Module): self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() - def to(self, device): + def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index e714e563..937e31ef 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -10,6 +10,8 @@ import numpy as np import torch import torch.nn as nn +from ..common.types import Device, make_device + class TensorAccessor(nn.Module): """ @@ -88,17 +90,19 @@ class TensorProperties(nn.Module): A mix-in class for storing tensors as properties with helper methods. """ - def __init__(self, dtype=torch.float32, device="cpu", **kwargs): + def __init__( + self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs + ): """ Args: dtype: data type to set for the inputs - device: str or torch.device + device: Device (as str or torch.device) kwargs: any number of keyword arguments. Any arguments which are - of type (float/int/tuple/tensor/array) are broadcasted and + of type (float/int/list/tuple/tensor/array) are broadcasted and other keyword arguments are set as attributes. 
""" super().__init__() - self.device = device + self.device = make_device(device) self._N = 0 if kwargs is not None: @@ -108,7 +112,7 @@ class TensorProperties(nn.Module): for k, v in kwargs.items(): if v is None or isinstance(v, (str, bool)): setattr(self, k, v) - elif isinstance(v, BROADCAST_TYPES): + elif isinstance(v, BROADCAST_TYPES): # pyre-fixme[6] args_to_broadcast[k] = v else: msg = "Arg %s with type %r is not broadcastable" @@ -152,17 +156,18 @@ class TensorProperties(nn.Module): msg = "Expected index of type int or slice; got %r" raise ValueError(msg % type(index)) - def to(self, device: str = "cpu"): + def to(self, device: Device = "cpu"): """ In place operation to move class properties which are tensors to a specified device. If self has a property "device", update this as well. """ + device_ = make_device(device) for k in dir(self): v = getattr(self, k) if k == "device": - setattr(self, k, device) - if torch.is_tensor(v) and v.device != device: - setattr(self, k, v.to(device)) + setattr(self, k, device_) + if torch.is_tensor(v) and v.device != device_: + setattr(self, k, v.to(device_)) return self def clone(self, other): @@ -257,28 +262,37 @@ class TensorProperties(nn.Module): return self -def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor: +def format_tensor( + input, dtype: torch.dtype = torch.float32, device: Device = "cpu" +) -> torch.Tensor: """ Helper function for converting a scalar value to a tensor. Args: input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor dtype: data type for the input - device: torch device on which the tensor should be placed. + device: Device (as str or torch.device) on which the tensor should be placed. Returns: input_vec: torch tensor with optional added batch dimension. 
""" + device_ = make_device(device) if not torch.is_tensor(input): - input = torch.tensor(input, dtype=dtype, device=device) + input = torch.tensor(input, dtype=dtype, device=device_) + if input.dim() == 0: input = input.view(1) - if input.device != device: - input = input.to(device=device) + + if input.device == device_: + return input + + input = input.to(device=device) return input -def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"): +def convert_to_tensors_and_broadcast( + *args, dtype: torch.dtype = torch.float32, device: Device = "cpu" +): """ Helper function to handle parsing an arbitrary number of inputs (*args) which all need to have the same batch dimension.