mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Fix type annotations for device type
Summary: Fix type annotations for device type Reviewed By: nikhilaravi Differential Revision: D28971179 fbshipit-source-id: 410b673c76dfd65ac51b2d144f17ed86a04a3058
This commit is contained in:
parent
1f9661e150
commit
626bf3fe23
@ -9,6 +9,7 @@ from typing import Dict, List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
||||||
from pytorch3d.renderer import HardPhongShader
|
from pytorch3d.renderer import HardPhongShader
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
@ -371,7 +372,7 @@ class R2N2(ShapeNetBase):
|
|||||||
idxs: Optional[List[int]] = None,
|
idxs: Optional[List[int]] = None,
|
||||||
view_idxs: Optional[List[int]] = None,
|
view_idxs: Optional[List[int]] = None,
|
||||||
shader_type=HardPhongShader,
|
shader_type=HardPhongShader,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -394,7 +395,7 @@ class R2N2(ShapeNetBase):
|
|||||||
idxs: List[int] of indices of models to be rendered in the dataset.
|
idxs: List[int] of indices of models to be rendered in the dataset.
|
||||||
shader_type: Shader to use for rendering. Examples include HardPhongShader
|
shader_type: Shader to use for rendering. Examples include HardPhongShader
|
||||||
(default), SoftPhongShader etc or any other type of valid Shader class.
|
(default), SoftPhongShader etc or any other type of valid Shader class.
|
||||||
device: torch.device on which the tensors should be located.
|
device: Device (as str or torch.device) on which the tensors should be located.
|
||||||
**kwargs: Accepts any of the kwargs that the renderer supports and any of the
|
**kwargs: Accepts any of the kwargs that the renderer supports and any of the
|
||||||
args that BlenderCamera supports.
|
args that BlenderCamera supports.
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.datasets.utils import collate_batched_meshes
|
from pytorch3d.datasets.utils import collate_batched_meshes
|
||||||
from pytorch3d.ops import cubify
|
from pytorch3d.ops import cubify
|
||||||
from pytorch3d.renderer import (
|
from pytorch3d.renderer import (
|
||||||
@ -431,13 +432,13 @@ class BlenderCamera(CamerasBase):
|
|||||||
(which uses Blender for rendering the views for each model).
|
(which uses Blender for rendering the views for each model).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, R=r, T=t, K=k, device="cpu"):
|
def __init__(self, R=r, T=t, K=k, device: Device = "cpu"):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
R: Rotation matrix of shape (N, 3, 3).
|
R: Rotation matrix of shape (N, 3, 3).
|
||||||
T: Translation matrix of shape (N, 3).
|
T: Translation matrix of shape (N, 3).
|
||||||
K: Intrinsic matrix of shape (N, 4, 4).
|
K: Intrinsic matrix of shape (N, 4, 4).
|
||||||
device: torch.device or str.
|
device: Device (as str or torch.device).
|
||||||
"""
|
"""
|
||||||
# The initializer formats all inputs to torch tensors and broadcasts
|
# The initializer formats all inputs to torch tensors and broadcasts
|
||||||
# all the inputs to have the same batch dimension where necessary.
|
# all the inputs to have the same batch dimension where necessary.
|
||||||
@ -450,7 +451,7 @@ class BlenderCamera(CamerasBase):
|
|||||||
|
|
||||||
|
|
||||||
def render_cubified_voxels(
|
def render_cubified_voxels(
|
||||||
voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs
|
voxels: torch.Tensor, shader_type=HardPhongShader, device: Device = "cpu", **kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh.
|
Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh.
|
||||||
@ -461,7 +462,7 @@ def render_cubified_voxels(
|
|||||||
shader_type: shader_type: shader_type: Shader to use for rendering. Examples
|
shader_type: shader_type: shader_type: Shader to use for rendering. Examples
|
||||||
include HardPhongShader (default), SoftPhongShader etc or any other type
|
include HardPhongShader (default), SoftPhongShader etc or any other type
|
||||||
of valid Shader class.
|
of valid Shader class.
|
||||||
device: torch.device on which the tensors should be located.
|
device: Device (as str or torch.device) on which the tensors should be located.
|
||||||
**kwargs: Accepts any of the kwargs that the renderer supports.
|
**kwargs: Accepts any of the kwargs that the renderer supports.
|
||||||
Returns:
|
Returns:
|
||||||
Batch of rendered images of shape (N, H, W, 3).
|
Batch of rendered images of shape (N, H, W, 3).
|
||||||
|
@ -4,6 +4,7 @@ import warnings
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.io import load_obj
|
from pytorch3d.io import load_obj
|
||||||
from pytorch3d.renderer import (
|
from pytorch3d.renderer import (
|
||||||
FoVPerspectiveCameras,
|
FoVPerspectiveCameras,
|
||||||
@ -105,7 +106,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
|
|||||||
sample_nums: Optional[List[int]] = None,
|
sample_nums: Optional[List[int]] = None,
|
||||||
idxs: Optional[List[int]] = None,
|
idxs: Optional[List[int]] = None,
|
||||||
shader_type=HardPhongShader,
|
shader_type=HardPhongShader,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -129,7 +130,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
|
|||||||
shader_type: Select shading. Valid options include HardPhongShader (default),
|
shader_type: Select shading. Valid options include HardPhongShader (default),
|
||||||
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
|
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
|
||||||
SoftSilhouetteShader.
|
SoftSilhouetteShader.
|
||||||
device: torch.device on which the tensors should be located.
|
device: Device (as str or torch.device) on which the tensors should be located.
|
||||||
**kwargs: Accepts any of the kwargs that the renderer supports.
|
**kwargs: Accepts any of the kwargs that the renderer supports.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -9,6 +9,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.io.utils import _open_file, _read_image
|
from pytorch3d.io.utils import _open_file, _read_image
|
||||||
|
|
||||||
|
|
||||||
@ -393,7 +394,7 @@ TextureImages = Dict[str, torch.Tensor]
|
|||||||
|
|
||||||
|
|
||||||
def _parse_mtl(
|
def _parse_mtl(
|
||||||
f, path_manager: PathManager, device="cpu"
|
f, path_manager: PathManager, device: Device = "cpu"
|
||||||
) -> Tuple[MaterialProperties, TextureFiles]:
|
) -> Tuple[MaterialProperties, TextureFiles]:
|
||||||
material_properties = {}
|
material_properties = {}
|
||||||
texture_files = {}
|
texture_files = {}
|
||||||
@ -474,7 +475,7 @@ def load_mtl(
|
|||||||
*,
|
*,
|
||||||
material_names: List[str],
|
material_names: List[str],
|
||||||
data_dir: str,
|
data_dir: str,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
) -> Tuple[MaterialProperties, TextureImages]:
|
) -> Tuple[MaterialProperties, TextureImages]:
|
||||||
"""
|
"""
|
||||||
@ -485,6 +486,7 @@ def load_mtl(
|
|||||||
f: a file-like object of the material information.
|
f: a file-like object of the material information.
|
||||||
material_names: a list of the material names found in the .obj file.
|
material_names: a list of the material names found in the .obj file.
|
||||||
data_dir: the directory where the material texture files are located.
|
data_dir: the directory where the material texture files are located.
|
||||||
|
device: Device (as str or torch.tensor) on which to return the new tensors.
|
||||||
path_manager: PathManager for interpreting both f and material_names.
|
path_manager: PathManager for interpreting both f and material_names.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -11,6 +11,7 @@ from typing import List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
|
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
|
||||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||||
from pytorch3d.renderer import TexturesAtlas, TexturesUV
|
from pytorch3d.renderer import TexturesAtlas, TexturesUV
|
||||||
@ -71,7 +72,7 @@ def load_obj(
|
|||||||
create_texture_atlas: bool = False,
|
create_texture_atlas: bool = False,
|
||||||
texture_atlas_size: int = 4,
|
texture_atlas_size: int = 4,
|
||||||
texture_wrap: Optional[str] = "repeat",
|
texture_wrap: Optional[str] = "repeat",
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
path_manager: Optional[PathManager] = None,
|
path_manager: Optional[PathManager] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -143,7 +144,7 @@ def load_obj(
|
|||||||
is ignored and a repeating pattern is formed.
|
is ignored and a repeating pattern is formed.
|
||||||
If `texture_mode="clamp"` the values are clamped to the range [0, 1].
|
If `texture_mode="clamp"` the values are clamped to the range [0, 1].
|
||||||
If None, then there is no transformation of the texture values.
|
If None, then there is no transformation of the texture values.
|
||||||
device: string or torch.device on which to return the new tensors.
|
device: Device (as str or torch.device) on which to return the new tensors.
|
||||||
path_manager: optionally a PathManager object to interpret paths.
|
path_manager: optionally a PathManager object to interpret paths.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -226,7 +227,7 @@ def load_obj(
|
|||||||
|
|
||||||
def load_objs_as_meshes(
|
def load_objs_as_meshes(
|
||||||
files: list,
|
files: list,
|
||||||
device=None,
|
device: Optional[Device] = None,
|
||||||
load_textures: bool = True,
|
load_textures: bool = True,
|
||||||
create_texture_atlas: bool = False,
|
create_texture_atlas: bool = False,
|
||||||
texture_atlas_size: int = 4,
|
texture_atlas_size: int = 4,
|
||||||
@ -293,7 +294,7 @@ class MeshObjFormat(MeshFormatInterpreter):
|
|||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: Union[str, Path],
|
||||||
include_textures: bool,
|
include_textures: bool,
|
||||||
device,
|
device: Device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
create_texture_atlas: bool = False,
|
create_texture_atlas: bool = False,
|
||||||
texture_atlas_size: int = 4,
|
texture_atlas_size: int = 4,
|
||||||
@ -497,7 +498,7 @@ def _load_materials(
|
|||||||
*,
|
*,
|
||||||
data_dir: str,
|
data_dir: str,
|
||||||
load_textures: bool,
|
load_textures: bool,
|
||||||
device,
|
device: Device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -508,7 +509,7 @@ def _load_materials(
|
|||||||
f: a file-like object of the material information.
|
f: a file-like object of the material information.
|
||||||
data_dir: the directory where the material texture files are located.
|
data_dir: the directory where the material texture files are located.
|
||||||
load_textures: whether textures should be loaded.
|
load_textures: whether textures should be loaded.
|
||||||
device: string or torch.device on which to return the new tensors.
|
device: Device (as str or torch.device) on which to return the new tensors.
|
||||||
path_manager: PathManager object to interpret paths.
|
path_manager: PathManager object to interpret paths.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -546,7 +547,7 @@ def _load_obj(
|
|||||||
texture_atlas_size: int = 4,
|
texture_atlas_size: int = 4,
|
||||||
texture_wrap: Optional[str] = "repeat",
|
texture_wrap: Optional[str] = "repeat",
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load a mesh from a file-like object. See load_obj function more details.
|
Load a mesh from a file-like object. See load_obj function more details.
|
||||||
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||||||
from typing import Deque, Optional, Union
|
from typing import Deque, Optional, Union
|
||||||
|
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.structures import Meshes, Pointclouds
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
|
|
||||||
from .obj_io import MeshObjFormat
|
from .obj_io import MeshObjFormat
|
||||||
@ -108,7 +109,7 @@ class IO:
|
|||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: Union[str, Path],
|
||||||
include_textures: bool = True,
|
include_textures: bool = True,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Meshes:
|
) -> Meshes:
|
||||||
"""
|
"""
|
||||||
@ -168,14 +169,14 @@ class IO:
|
|||||||
raise ValueError(f"No mesh interpreter found to write to {path}.")
|
raise ValueError(f"No mesh interpreter found to write to {path}.")
|
||||||
|
|
||||||
def load_pointcloud(
|
def load_pointcloud(
|
||||||
self, path: Union[str, Path], device="cpu", **kwargs
|
self, path: Union[str, Path], device: Device = "cpu", **kwargs
|
||||||
) -> Pointclouds:
|
) -> Pointclouds:
|
||||||
"""
|
"""
|
||||||
Attempt to load a point cloud from the given file, using a registered format.
|
Attempt to load a point cloud from the given file, using a registered format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: file to read
|
path: file to read
|
||||||
device: torch.device on which to load the data.
|
device: Device (as str or torch.device) on which to load the data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
new Pointclouds object containing one mesh.
|
new Pointclouds object containing one mesh.
|
||||||
|
@ -10,6 +10,8 @@ import torch
|
|||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..common.types import Device
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def nullcontext(x):
|
def nullcontext(x):
|
||||||
@ -31,7 +33,7 @@ def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
|
|||||||
|
|
||||||
|
|
||||||
def _make_tensor(
|
def _make_tensor(
|
||||||
data, cols: int, dtype: torch.dtype, device: str = "cpu"
|
data, cols: int, dtype: torch.dtype, device: Device = "cpu"
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return a 2D tensor with the specified cols and dtype filled with data,
|
Return a 2D tensor with the specified cols and dtype filled with data,
|
||||||
|
@ -7,6 +7,7 @@ from typing import Optional, Sequence, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.transforms import Rotate, Transform3d, Translate
|
from pytorch3d.transforms import Rotate, Transform3d, Translate
|
||||||
|
|
||||||
from .utils import TensorProperties, convert_to_tensors_and_broadcast
|
from .utils import TensorProperties, convert_to_tensors_and_broadcast
|
||||||
@ -290,7 +291,7 @@ def OpenGLPerspectiveCameras(
|
|||||||
degrees: bool = True,
|
degrees: bool = True,
|
||||||
R=_R,
|
R=_R,
|
||||||
T=_T,
|
T=_T,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
|
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
|
||||||
@ -358,7 +359,7 @@ class FoVPerspectiveCameras(CamerasBase):
|
|||||||
R=_R,
|
R=_R,
|
||||||
T=_T,
|
T=_T,
|
||||||
K=None,
|
K=None,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -373,7 +374,7 @@ class FoVPerspectiveCameras(CamerasBase):
|
|||||||
T: Translation matrix of shape (N, 3)
|
T: Translation matrix of shape (N, 3)
|
||||||
K: (optional) A calibration matrix of shape (N, 4, 4)
|
K: (optional) A calibration matrix of shape (N, 4, 4)
|
||||||
If provided, don't need znear, zfar, fov, aspect_ratio, degrees
|
If provided, don't need znear, zfar, fov, aspect_ratio, degrees
|
||||||
device: torch.device or string
|
device: Device (as str or torch.device)
|
||||||
"""
|
"""
|
||||||
# The initializer formats all inputs to torch tensors and broadcasts
|
# The initializer formats all inputs to torch tensors and broadcasts
|
||||||
# all the inputs to have the same batch dimension where necessary.
|
# all the inputs to have the same batch dimension where necessary.
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ..common.types import Device
|
||||||
from .utils import TensorProperties, convert_to_tensors_and_broadcast
|
from .utils import TensorProperties, convert_to_tensors_and_broadcast
|
||||||
|
|
||||||
|
|
||||||
@ -158,7 +159,7 @@ class DirectionalLights(TensorProperties):
|
|||||||
diffuse_color=((0.3, 0.3, 0.3),),
|
diffuse_color=((0.3, 0.3, 0.3),),
|
||||||
specular_color=((0.2, 0.2, 0.2),),
|
specular_color=((0.2, 0.2, 0.2),),
|
||||||
direction=((0, 1, 0),),
|
direction=((0, 1, 0),),
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -166,7 +167,7 @@ class DirectionalLights(TensorProperties):
|
|||||||
diffuse_color: RGB color of the diffuse component.
|
diffuse_color: RGB color of the diffuse component.
|
||||||
specular_color: RGB color of the specular component.
|
specular_color: RGB color of the specular component.
|
||||||
direction: (x, y, z) direction vector of the light.
|
direction: (x, y, z) direction vector of the light.
|
||||||
device: torch.device on which the tensors should be located
|
device: Device (as str or torch.device) on which the tensors should be located
|
||||||
|
|
||||||
The inputs can each be
|
The inputs can each be
|
||||||
- 3 element tuple/list or list of lists
|
- 3 element tuple/list or list of lists
|
||||||
@ -219,7 +220,7 @@ class PointLights(TensorProperties):
|
|||||||
diffuse_color=((0.3, 0.3, 0.3),),
|
diffuse_color=((0.3, 0.3, 0.3),),
|
||||||
specular_color=((0.2, 0.2, 0.2),),
|
specular_color=((0.2, 0.2, 0.2),),
|
||||||
location=((0, 1, 0),),
|
location=((0, 1, 0),),
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -227,7 +228,7 @@ class PointLights(TensorProperties):
|
|||||||
diffuse_color: RGB color of the diffuse component
|
diffuse_color: RGB color of the diffuse component
|
||||||
specular_color: RGB color of the specular component
|
specular_color: RGB color of the specular component
|
||||||
location: xyz position of the light.
|
location: xyz position of the light.
|
||||||
device: torch.device on which the tensors should be located
|
device: Device (as str or torch.device) on which the tensors should be located
|
||||||
|
|
||||||
The inputs can each be
|
The inputs can each be
|
||||||
- 3 element tuple/list or list of lists
|
- 3 element tuple/list or list of lists
|
||||||
@ -275,14 +276,14 @@ class AmbientLights(TensorProperties):
|
|||||||
not used in rendering.
|
not used in rendering.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, ambient_color=None, device="cpu"):
|
def __init__(self, *, ambient_color=None, device: Device = "cpu"):
|
||||||
"""
|
"""
|
||||||
If ambient_color is provided, it should be a sequence of
|
If ambient_color is provided, it should be a sequence of
|
||||||
triples of floats.
|
triples of floats.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ambient_color: RGB color
|
ambient_color: RGB color
|
||||||
device: torch.device on which the tensors should be located
|
device: Device (as str or torch.device) on which the tensors should be located
|
||||||
|
|
||||||
The ambient_color if provided, should be
|
The ambient_color if provided, should be
|
||||||
- 3 element tuple/list or list of lists
|
- 3 element tuple/list or list of lists
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..common.types import Device
|
||||||
from .utils import TensorProperties
|
from .utils import TensorProperties
|
||||||
|
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ class Materials(TensorProperties):
|
|||||||
diffuse_color=((1, 1, 1),),
|
diffuse_color=((1, 1, 1),),
|
||||||
specular_color=((1, 1, 1),),
|
specular_color=((1, 1, 1),),
|
||||||
shininess=64,
|
shininess=64,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -29,7 +30,7 @@ class Materials(TensorProperties):
|
|||||||
the focus of the specular highlight with a high value
|
the focus of the specular highlight with a high value
|
||||||
resulting in a concentrated highlight. Shininess values
|
resulting in a concentrated highlight. Shininess values
|
||||||
can range from 0-1000.
|
can range from 0-1000.
|
||||||
device: torch.device or string
|
device: Device (as str or torch.device) on which the tensors should be located
|
||||||
|
|
||||||
ambient_color, diffuse_color and specular_color can be of shape
|
ambient_color, diffuse_color and specular_color can be of shape
|
||||||
(1, 3) or (N, 3). shininess can be of shape (1) or (N).
|
(1, 3) or (N, 3). shininess can be of shape (1) or (N).
|
||||||
|
@ -5,6 +5,7 @@ import warnings
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ...common.types import Device
|
||||||
from ..blending import (
|
from ..blending import (
|
||||||
BlendParams,
|
BlendParams,
|
||||||
hard_rgb_blend,
|
hard_rgb_blend,
|
||||||
@ -40,7 +41,12 @@ class HardPhongShader(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
self,
|
||||||
|
device: Device = "cpu",
|
||||||
|
cameras=None,
|
||||||
|
lights=None,
|
||||||
|
materials=None,
|
||||||
|
blend_params=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -50,7 +56,7 @@ class HardPhongShader(nn.Module):
|
|||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
@ -95,7 +101,12 @@ class SoftPhongShader(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
self,
|
||||||
|
device: Device = "cpu",
|
||||||
|
cameras=None,
|
||||||
|
lights=None,
|
||||||
|
materials=None,
|
||||||
|
blend_params=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -105,7 +116,7 @@ class SoftPhongShader(nn.Module):
|
|||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
@ -155,7 +166,12 @@ class HardGouraudShader(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
self,
|
||||||
|
device: Device = "cpu",
|
||||||
|
cameras=None,
|
||||||
|
lights=None,
|
||||||
|
materials=None,
|
||||||
|
blend_params=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -165,7 +181,7 @@ class HardGouraudShader(nn.Module):
|
|||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
@ -214,7 +230,12 @@ class SoftGouraudShader(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
self,
|
||||||
|
device: Device = "cpu",
|
||||||
|
cameras=None,
|
||||||
|
lights=None,
|
||||||
|
materials=None,
|
||||||
|
blend_params=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -224,7 +245,7 @@ class SoftGouraudShader(nn.Module):
|
|||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
@ -255,7 +276,7 @@ class SoftGouraudShader(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def TexturedSoftPhongShader(
|
def TexturedSoftPhongShader(
|
||||||
device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
|
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
|
||||||
@ -290,7 +311,12 @@ class HardFlatShader(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
self,
|
||||||
|
device: Device = "cpu",
|
||||||
|
cameras=None,
|
||||||
|
lights=None,
|
||||||
|
materials=None,
|
||||||
|
blend_params=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -300,7 +326,7 @@ class HardFlatShader(nn.Module):
|
|||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
|
@ -10,6 +10,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ..common.types import Device, make_device
|
||||||
|
|
||||||
|
|
||||||
class TensorAccessor(nn.Module):
|
class TensorAccessor(nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -88,17 +90,19 @@ class TensorProperties(nn.Module):
|
|||||||
A mix-in class for storing tensors as properties with helper methods.
|
A mix-in class for storing tensors as properties with helper methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dtype=torch.float32, device="cpu", **kwargs):
|
def __init__(
|
||||||
|
self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
dtype: data type to set for the inputs
|
dtype: data type to set for the inputs
|
||||||
device: str or torch.device
|
device: Device (as str or torch.device)
|
||||||
kwargs: any number of keyword arguments. Any arguments which are
|
kwargs: any number of keyword arguments. Any arguments which are
|
||||||
of type (float/int/tuple/tensor/array) are broadcasted and
|
of type (float/int/list/tuple/tensor/array) are broadcasted and
|
||||||
other keyword arguments are set as attributes.
|
other keyword arguments are set as attributes.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = make_device(device)
|
||||||
self._N = 0
|
self._N = 0
|
||||||
if kwargs is not None:
|
if kwargs is not None:
|
||||||
|
|
||||||
@ -108,7 +112,7 @@ class TensorProperties(nn.Module):
|
|||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if v is None or isinstance(v, (str, bool)):
|
if v is None or isinstance(v, (str, bool)):
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
elif isinstance(v, BROADCAST_TYPES):
|
elif isinstance(v, BROADCAST_TYPES): # pyre-fixme[6]
|
||||||
args_to_broadcast[k] = v
|
args_to_broadcast[k] = v
|
||||||
else:
|
else:
|
||||||
msg = "Arg %s with type %r is not broadcastable"
|
msg = "Arg %s with type %r is not broadcastable"
|
||||||
@ -152,17 +156,18 @@ class TensorProperties(nn.Module):
|
|||||||
msg = "Expected index of type int or slice; got %r"
|
msg = "Expected index of type int or slice; got %r"
|
||||||
raise ValueError(msg % type(index))
|
raise ValueError(msg % type(index))
|
||||||
|
|
||||||
def to(self, device: str = "cpu"):
|
def to(self, device: Device = "cpu"):
|
||||||
"""
|
"""
|
||||||
In place operation to move class properties which are tensors to a
|
In place operation to move class properties which are tensors to a
|
||||||
specified device. If self has a property "device", update this as well.
|
specified device. If self has a property "device", update this as well.
|
||||||
"""
|
"""
|
||||||
|
device_ = make_device(device)
|
||||||
for k in dir(self):
|
for k in dir(self):
|
||||||
v = getattr(self, k)
|
v = getattr(self, k)
|
||||||
if k == "device":
|
if k == "device":
|
||||||
setattr(self, k, device)
|
setattr(self, k, device_)
|
||||||
if torch.is_tensor(v) and v.device != device:
|
if torch.is_tensor(v) and v.device != device_:
|
||||||
setattr(self, k, v.to(device))
|
setattr(self, k, v.to(device_))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def clone(self, other):
|
def clone(self, other):
|
||||||
@ -257,28 +262,37 @@ class TensorProperties(nn.Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor:
|
def format_tensor(
|
||||||
|
input, dtype: torch.dtype = torch.float32, device: Device = "cpu"
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Helper function for converting a scalar value to a tensor.
|
Helper function for converting a scalar value to a tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
|
input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
|
||||||
dtype: data type for the input
|
dtype: data type for the input
|
||||||
device: torch device on which the tensor should be placed.
|
device: Device (as str or torch.device) on which the tensor should be placed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
input_vec: torch tensor with optional added batch dimension.
|
input_vec: torch tensor with optional added batch dimension.
|
||||||
"""
|
"""
|
||||||
|
device_ = make_device(device)
|
||||||
if not torch.is_tensor(input):
|
if not torch.is_tensor(input):
|
||||||
input = torch.tensor(input, dtype=dtype, device=device)
|
input = torch.tensor(input, dtype=dtype, device=device_)
|
||||||
|
|
||||||
if input.dim() == 0:
|
if input.dim() == 0:
|
||||||
input = input.view(1)
|
input = input.view(1)
|
||||||
if input.device != device:
|
|
||||||
input = input.to(device=device)
|
if input.device == device_:
|
||||||
|
return input
|
||||||
|
|
||||||
|
input = input.to(device=device)
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"):
|
def convert_to_tensors_and_broadcast(
|
||||||
|
*args, dtype: torch.dtype = torch.float32, device: Device = "cpu"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Helper function to handle parsing an arbitrary number of inputs (*args)
|
Helper function to handle parsing an arbitrary number of inputs (*args)
|
||||||
which all need to have the same batch dimension.
|
which all need to have the same batch dimension.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user