From f593bfd3c258b0ff2b7bdbabfb06ab5210b43a52 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Fri, 25 Jun 2021 19:55:26 -0700 Subject: [PATCH] More type annotations Summary: More type annotations: device, shaders, pluggable I/O, stats in NeRF project, cameras, textures, etc... Reviewed By: nikhilaravi Differential Revision: D29327396 fbshipit-source-id: cdf0ceaaa010e22423088752688c8dd81f1acc3c --- projects/nerf/nerf/dataset.py | 2 +- projects/nerf/nerf/stats.py | 16 ++-- pytorch3d/io/experimental_gltf_io.py | 9 +- pytorch3d/io/obj_io.py | 15 ++- pytorch3d/io/off_io.py | 7 +- pytorch3d/io/pluggable_formats.py | 18 ++-- pytorch3d/io/ply_io.py | 16 ++-- pytorch3d/io/utils.py | 5 +- pytorch3d/renderer/blending.py | 27 +++--- pytorch3d/renderer/cameras.py | 98 ++++++++++++-------- pytorch3d/renderer/mesh/shader.py | 89 ++++++++++-------- pytorch3d/renderer/mesh/textures.py | 12 +-- pytorch3d/renderer/utils.py | 6 +- pytorch3d/structures/volumes.py | 10 +- pytorch3d/transforms/rotation_conversions.py | 14 ++- 15 files changed, 196 insertions(+), 148 deletions(-) diff --git a/projects/nerf/nerf/dataset.py b/projects/nerf/nerf/dataset.py index e8978d07..f4fae445 100644 --- a/projects/nerf/nerf/dataset.py +++ b/projects/nerf/nerf/dataset.py @@ -136,7 +136,7 @@ def download_data( dataset_names: Optional[List[str]] = None, data_root: str = DEFAULT_DATA_ROOT, url_root: str = DEFAULT_URL_ROOT, -): +) -> None: """ Downloads the relevant dataset files. diff --git a/projects/nerf/nerf/stats.py b/projects/nerf/nerf/stats.py index 95708ab6..ddb07640 100644 --- a/projects/nerf/nerf/stats.py +++ b/projects/nerf/nerf/stats.py @@ -29,7 +29,7 @@ class AverageMeter: self.history = [] self.reset() - def reset(self): + def reset(self) -> None: """ Reset the running average meter. """ @@ -38,7 +38,7 @@ class AverageMeter: self.sum = 0 self.count = 0 - def update(self, val: float, n: int = 1, epoch: int = 0): + def update(self, val: float, n: int = 1, epoch: int = 0) -> None: """ Updates the average meter with a value `val`. @@ -123,7 +123,7 @@ class Stats: self.plot_file = plot_file self.hard_reset(epoch=epoch) - def reset(self): + def reset(self) -> None: """ Called before an epoch to clear current epoch buffers. """ @@ -138,7 +138,7 @@ class Stats: # Set a new timestamp. self._epoch_start = time.time() - def hard_reset(self, epoch: int = -1): + def hard_reset(self, epoch: int = -1) -> None: """ Erases all logged data. """ @@ -149,7 +149,7 @@ class Stats: self.stats = {} self.reset() - def new_epoch(self): + def new_epoch(self) -> None: """ Initializes a new epoch. """ @@ -166,7 +166,7 @@ class Stats: val = float(val.sum()) return val - def update(self, preds: dict, stat_set: str = "train"): + def update(self, preds: dict, stat_set: str = "train") -> None: """ Update the internal logs with metrics of a training step. @@ -211,7 +211,7 @@ class Stats: if val is not None: self.stats[stat_set][stat].update(val, epoch=epoch, n=1) - def print(self, max_it: Optional[int] = None, stat_set: str = "train"): + def print(self, max_it: Optional[int] = None, stat_set: str = "train") -> None: """ Print the current values of all stored stats. @@ -247,7 +247,7 @@ class Stats: viz: Visdom = None, visdom_env: Optional[str] = None, plot_file: Optional[str] = None, - ): + ) -> None: """ Plot the line charts of the history of the stats. diff --git a/pytorch3d/io/experimental_gltf_io.py b/pytorch3d/io/experimental_gltf_io.py index ef38a8df..7a42b0c0 100644 --- a/pytorch3d/io/experimental_gltf_io.py +++ b/pytorch3d/io/experimental_gltf_io.py @@ -42,14 +42,13 @@ from base64 import b64decode from collections import deque from enum import IntEnum from io import BytesIO -from pathlib import Path from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, cast import numpy as np import torch from iopath.common.file_io import PathManager from PIL import Image -from pytorch3d.io.utils import _open_file +from pytorch3d.io.utils import PathOrStr, _open_file from pytorch3d.renderer.mesh import TexturesBase, TexturesUV, TexturesVertex from pytorch3d.structures import Meshes, join_meshes_as_scene from pytorch3d.transforms import Transform3d, quaternion_to_matrix @@ -498,7 +497,7 @@ class _GLTFLoader: def load_meshes( - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, include_textures: bool = True, ) -> List[Tuple[Optional[str], Meshes]]: @@ -544,7 +543,7 @@ class MeshGlbFormat(MeshFormatInterpreter): def read( self, - path: Union[str, Path], + path: PathOrStr, include_textures: bool, device, path_manager: PathManager, @@ -566,7 +565,7 @@ class MeshGlbFormat(MeshFormatInterpreter): def save( self, data: Meshes, - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, binary: Optional[bool], **kwargs, diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index 05df1bfd..16cb9eb1 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -18,7 +18,12 @@ from iopath.common.file_io import PathManager from PIL import Image from pytorch3d.common.types import Device from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas -from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file +from pytorch3d.io.utils import ( + PathOrStr, + _check_faces_indices, + _make_tensor, + _open_file, +) from pytorch3d.renderer import TexturesAtlas, TexturesUV from pytorch3d.structures import Meshes, join_meshes_as_batch @@ -213,7 +218,7 @@ def load_obj( None. """ data_dir = "./" - if isinstance(f, (str, bytes, os.PathLike)): + if isinstance(f, (str, bytes, Path)): data_dir = os.path.dirname(f) if path_manager is None: path_manager = PathManager() @@ -297,7 +302,7 @@ class MeshObjFormat(MeshFormatInterpreter): def read( self, - path: Union[str, Path], + path: PathOrStr, include_textures: bool, device: Device, path_manager: PathManager, @@ -322,7 +327,7 @@ class MeshObjFormat(MeshFormatInterpreter): def save( self, data: Meshes, - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, binary: Optional[bool], decimal_places: Optional[int] = None, @@ -650,7 +655,7 @@ def _load_obj( def save_obj( - f: Union[str, os.PathLike], + f: PathOrStr, verts, faces, decimal_places: Optional[int] = None, diff --git a/pytorch3d/io/off_io.py b/pytorch3d/io/off_io.py index 7fe70960..84217656 100644 --- a/pytorch3d/io/off_io.py +++ b/pytorch3d/io/off_io.py @@ -13,13 +13,12 @@ This format is introduced, for example, at http://www.geomview.org/docs/html/OFF.html . """ import warnings -from pathlib import Path from typing import Optional, Tuple, Union, cast import numpy as np import torch from iopath.common.file_io import PathManager -from pytorch3d.io.utils import _check_faces_indices, _open_file +from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _open_file from pytorch3d.renderer import TexturesAtlas, TexturesVertex from pytorch3d.structures import Meshes @@ -424,7 +423,7 @@ class MeshOffFormat(MeshFormatInterpreter): def read( self, - path: Union[str, Path], + path: PathOrStr, include_textures: bool, device, path_manager: PathManager, @@ -460,7 +459,7 @@ class MeshOffFormat(MeshFormatInterpreter): def save( self, data: Meshes, - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, binary: Optional[bool], decimal_places: Optional[int] = None, diff --git a/pytorch3d/io/pluggable_formats.py b/pytorch3d/io/pluggable_formats.py index 35cb919f..0c52f853 100644 --- a/pytorch3d/io/pluggable_formats.py +++ b/pytorch3d/io/pluggable_formats.py @@ -5,10 +5,12 @@ # LICENSE file in the root directory of this source tree. -from pathlib import Path +import pathlib from typing import Optional, Tuple, Union from iopath.common.file_io import PathManager +from pytorch3d.common.types import Device +from pytorch3d.io.utils import PathOrStr from pytorch3d.structures import Meshes, Pointclouds @@ -20,14 +22,14 @@ its load_* and save_* functions. """ -def endswith(path, suffixes: Tuple[str, ...]) -> bool: +def endswith(path: PathOrStr, suffixes: Tuple[str, ...]) -> bool: """ Returns whether the path ends with one of the given suffixes. If `path` is not actually a path, returns True. This is useful for allowing interpreters to bypass inappropriate paths, but always accepting streams. """ - if isinstance(path, Path): + if isinstance(path, pathlib.Path): return path.suffix.lower() in suffixes if isinstance(path, str): return path.lower().endswith(suffixes) @@ -42,9 +44,9 @@ class MeshFormatInterpreter: def read( self, - path: Union[str, Path], + path: PathOrStr, include_textures: bool, - device, + device: Device, path_manager: PathManager, **kwargs, ) -> Optional[Meshes]: @@ -68,7 +70,7 @@ class MeshFormatInterpreter: def save( self, data: Meshes, - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, binary: Optional[bool], **kwargs, @@ -96,7 +98,7 @@ class PointcloudFormatInterpreter: """ def read( - self, path: Union[str, Path], device, path_manager: PathManager, **kwargs + self, path: PathOrStr, device: Device, path_manager: PathManager, **kwargs ) -> Optional[Pointclouds]: """ Read the data from the specified file and return it as @@ -117,7 +119,7 @@ class PointcloudFormatInterpreter: def save( self, data: Pointclouds, - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, binary: Optional[bool], **kwargs, diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 80b6fc71..802d4636 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -15,13 +15,17 @@ import sys import warnings from collections import namedtuple from io import BytesIO, TextIOBase -from pathlib import Path from typing import List, Optional, Tuple, Union, cast import numpy as np import torch from iopath.common.file_io import PathManager -from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file +from pytorch3d.io.utils import ( + PathOrStr, + _check_faces_indices, + _make_tensor, + _open_file, +) from pytorch3d.renderer import TexturesVertex from pytorch3d.structures import Meshes, Pointclouds @@ -1237,7 +1241,7 @@ class MeshPlyFormat(MeshFormatInterpreter): def read( self, - path: Union[str, Path], + path: PathOrStr, include_textures: bool, device, path_manager: PathManager, @@ -1269,7 +1273,7 @@ class MeshPlyFormat(MeshFormatInterpreter): def save( self, data: Meshes, - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, binary: Optional[bool], decimal_places: Optional[int] = None, @@ -1318,7 +1322,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): def read( self, - path: Union[str, Path], + path: PathOrStr, device, path_manager: PathManager, **kwargs, @@ -1339,7 +1343,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): def save( self, data: Pointclouds, - path: Union[str, Path], + path: PathOrStr, path_manager: PathManager, binary: Optional[bool], decimal_places: Optional[int] = None, diff --git a/pytorch3d/io/utils.py b/pytorch3d/io/utils.py index 08ac27c2..55a8e772 100644 --- a/pytorch3d/io/utils.py +++ b/pytorch3d/io/utils.py @@ -7,7 +7,7 @@ import contextlib import pathlib import warnings -from typing import IO, ContextManager, Optional +from typing import IO, ContextManager, Optional, Union import numpy as np import torch @@ -25,6 +25,9 @@ def nullcontext(x): yield x +PathOrStr = Union[pathlib.Path, str] + + def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]: if isinstance(f, str): f = path_manager.open(f, mode) diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index e177d793..d91a4969 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -20,10 +20,12 @@ from pytorch3d import _C class BlendParams(NamedTuple): sigma: float = 1e-4 gamma: float = 1e-4 - background_color: Sequence = (1.0, 1.0, 1.0) + background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0) -def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: +def hard_rgb_blend( + colors: torch.Tensor, fragments, blend_params: BlendParams +) -> torch.Tensor: """ Naive blending of top K faces to return an RGBA image - **RGB** - choose color of the closest point i.e. K=0 @@ -47,10 +49,11 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: # Mask for the background. is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W) - if torch.is_tensor(blend_params.background_color): - background_color = blend_params.background_color.to(device) + background_color_ = blend_params.background_color + if isinstance(background_color_, torch.Tensor): + background_color = background_color_.to(device) else: - background_color = colors.new_tensor(blend_params.background_color) # (3) + background_color = colors.new_tensor(background_color_) # pyre-fixme[16] # Find out how much background_color needs to be expanded to be used for masked_scatter. num_background_pixels = is_background.sum() @@ -90,7 +93,7 @@ class _SigmoidAlphaBlend(torch.autograd.Function): _sigmoid_alpha = _SigmoidAlphaBlend.apply -def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: +def sigmoid_alpha_blend(colors, fragments, blend_params: BlendParams) -> torch.Tensor: """ Silhouette blending to return an RGBA image - **RGB** - choose color of the closest point. @@ -121,9 +124,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: def softmax_rgb_blend( - colors, + colors: torch.Tensor, fragments, - blend_params, + blend_params: BlendParams, znear: Union[float, torch.Tensor] = 1.0, zfar: Union[float, torch.Tensor] = 100, ) -> torch.Tensor: @@ -167,11 +170,11 @@ def softmax_rgb_blend( N, H, W, K = fragments.pix_to_face.shape device = fragments.pix_to_face.device pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device) - background = blend_params.background_color - if not torch.is_tensor(background): - background = torch.tensor(background, dtype=torch.float32, device=device) + background_ = blend_params.background_color + if not isinstance(background_, torch.Tensor): + background = torch.tensor(background_, dtype=torch.float32, device=device) else: - background = background.to(device) + background = background_.to(device) # Weight for background color eps = 1e-10 diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index 41f4d371..1c61aa6f 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -172,9 +172,11 @@ class CamerasBase(TensorProperties): A Transform3d object which represents a batch of transforms of shape (N, 3, 3) """ - self.R = kwargs.get("R", self.R) # pyre-ignore[16] - self.T = kwargs.get("T", self.T) # pyre-ignore[16] - world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T) + R: torch.Tensor = kwargs.get("R", self.R) + T: torch.Tensor = kwargs.get("T", self.T) + self.R = R # pyre-ignore[16] + self.T = T # pyre-ignore[16] + world_to_view_transform = get_world_to_view_transform(R=R, T=T) return world_to_view_transform def get_full_projection_transform(self, **kwargs) -> Transform3d: @@ -195,8 +197,8 @@ class CamerasBase(TensorProperties): a Transform3d object which represents a batch of transforms of shape (N, 3, 3) """ - self.R = kwargs.get("R", self.R) # pyre-ignore[16] - self.T = kwargs.get("T", self.T) # pyre-ignore[16] + self.R: torch.Tensor = kwargs.get("R", self.R) # pyre-ignore[16] + self.T: torch.Tensor = kwargs.get("T", self.T) # pyre-ignore[16] world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T) view_to_ndc_transform = self.get_projection_transform(**kwargs) return world_to_view_transform.compose(view_to_ndc_transform) @@ -293,10 +295,10 @@ def OpenGLPerspectiveCameras( aspect_ratio=1.0, fov=60.0, degrees: bool = True, - R=_R, - T=_T, + R: torch.Tensor = _R, + T: torch.Tensor = _T, device: Device = "cpu", -): +) -> "FoVPerspectiveCameras": """ OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead. Preserving OpenGLPerspectiveCameras for backward compatibility. @@ -360,9 +362,9 @@ class FoVPerspectiveCameras(CamerasBase): aspect_ratio=1.0, fov=60.0, degrees: bool = True, - R=_R, - T=_T, - K=None, + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, device: Device = "cpu", ) -> None: """ @@ -397,7 +399,7 @@ class FoVPerspectiveCameras(CamerasBase): self.degrees = degrees def compute_projection_matrix( - self, znear, zfar, fov, aspect_ratio, degrees + self, znear, zfar, fov, aspect_ratio, degrees: bool ) -> torch.Tensor: """ Compute the calibration matrix K of shape (N, 4, 4) @@ -559,10 +561,10 @@ def OpenGLOrthographicCameras( left=-1.0, right=1.0, scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) - R=_R, - T=_T, - device="cpu", -): + R: torch.Tensor = _R, + T: torch.Tensor = _T, + device: Device = "cpu", +) -> "FoVOrthographicCameras": """ OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead. Preserving OpenGLOrthographicCameras for backward compatibility. @@ -605,10 +607,10 @@ class FoVOrthographicCameras(CamerasBase): max_x=1.0, min_x=-1.0, scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) - R=_R, - T=_T, - K=None, - device="cpu", + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", ): """ @@ -784,8 +786,12 @@ we assume the parameters are in screen space. def SfMPerspectiveCameras( - focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu" -): + focal_length=1.0, + principal_point=((0.0, 0.0),), + R: torch.Tensor = _R, + T: torch.Tensor = _T, + device: Device = "cpu", +) -> "PerspectiveCameras": """ SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead. Preserving SfMPerspectiveCameras for backward compatibility. @@ -843,10 +849,10 @@ class PerspectiveCameras(CamerasBase): self, focal_length=1.0, principal_point=((0.0, 0.0),), - R=_R, - T=_T, - K=None, - device="cpu", + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", image_size=((-1, -1),), ) -> None: """ @@ -950,8 +956,12 @@ class PerspectiveCameras(CamerasBase): def SfMOrthographicCameras( - focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu" -): + focal_length=1.0, + principal_point=((0.0, 0.0),), + R: torch.Tensor = _R, + T: torch.Tensor = _T, + device: Device = "cpu", +) -> "OrthographicCameras": """ SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead. Preserving SfMOrthographicCameras for backward compatibility. @@ -1008,10 +1018,10 @@ class OrthographicCameras(CamerasBase): self, focal_length=1.0, principal_point=((0.0, 0.0),), - R=_R, - T=_T, - K=None, - device="cpu", + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", image_size=((-1, -1),), ) -> None: """ @@ -1116,8 +1126,8 @@ class OrthographicCameras(CamerasBase): def _get_sfm_calibration_matrix( - N, - device, + N: int, + device: Device, focal_length, principal_point, orthographic: bool = False, @@ -1216,7 +1226,9 @@ def _get_sfm_calibration_matrix( ################################################ -def get_world_to_view_transform(R=_R, T=_T) -> Transform3d: +def get_world_to_view_transform( + R: torch.Tensor = _R, T: torch.Tensor = _T +) -> Transform3d: """ This function returns a Transform3d representing the transformation matrix to go from world space to view space by applying a rotation and @@ -1250,13 +1262,17 @@ def get_world_to_view_transform(R=_R, T=_T) -> Transform3d: raise ValueError(msg % repr(R.shape)) # Create a Transform3d object - T = Translate(T, device=T.device) - R = Rotate(R, device=R.device) - return R.compose(T) + T_ = Translate(T, device=T.device) + R_ = Rotate(R, device=R.device) + return R_.compose(T_) def camera_position_from_spherical_angles( - distance, elevation, azimuth, degrees: bool = True, device: str = "cpu" + distance: float, + elevation: float, + azimuth: float, + degrees: bool = True, + device: Device = "cpu", ) -> torch.Tensor: """ Calculate the location of the camera based on the distance away from @@ -1294,7 +1310,7 @@ def camera_position_from_spherical_angles( def look_at_rotation( - camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: str = "cpu" + camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: Device = "cpu" ) -> torch.Tensor: """ This function takes a vector 'camera_position' which specifies the location @@ -1351,7 +1367,7 @@ def look_at_view_transform( eye: Optional[Sequence] = None, at=((0, 0, 0),), # (1, 3) up=((0, 1, 0),), # (1, 3) - device="cpu", + device: Device = "cpu", ) -> Tuple[torch.Tensor, torch.Tensor]: """ This function returns a rotation and translation matrix diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index c3b441e4..6977c04b 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -5,11 +5,13 @@ # LICENSE file in the root directory of this source tree. import warnings +from typing import Optional import torch import torch.nn as nn from ...common.types import Device +from ...structures.meshes import Meshes from ..blending import ( BlendParams, hard_rgb_blend, @@ -18,6 +20,8 @@ from ..blending import ( ) from ..lighting import PointLights from ..materials import Materials +from ..utils import TensorProperties +from .rasterizer import Fragments from .shading import flat_shading, gouraud_shading, phong_shading @@ -47,10 +51,10 @@ class HardPhongShader(nn.Module): def __init__( self, device: Device = "cpu", - cameras=None, - lights=None, - materials=None, - blend_params=None, + cameras: Optional[TensorProperties] = None, + lights: Optional[TensorProperties] = None, + materials: Optional[Materials] = None, + blend_params: Optional[BlendParams] = None, ) -> None: super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -62,13 +66,14 @@ class HardPhongShader(nn.Module): def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - if self.cameras is not None: - self.cameras = self.cameras.to(device) + cameras = self.cameras + if cameras is not None: + self.cameras = cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self - def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: msg = "Cameras must be specified either at initialization \ @@ -108,10 +113,10 @@ class SoftPhongShader(nn.Module): def __init__( self, device: Device = "cpu", - cameras=None, - lights=None, - materials=None, - blend_params=None, + cameras: Optional[TensorProperties] = None, + lights: Optional[TensorProperties] = None, + materials: Optional[Materials] = None, + blend_params: Optional[BlendParams] = None, ) -> None: super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -123,13 +128,14 @@ class SoftPhongShader(nn.Module): def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - if self.cameras is not None: - self.cameras = self.cameras.to(device) + cameras = self.cameras + if cameras is not None: + self.cameras = cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self - def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: msg = "Cameras must be specified either at initialization \ @@ -174,10 +180,10 @@ class HardGouraudShader(nn.Module): def __init__( self, device: Device = "cpu", - cameras=None, - lights=None, - materials=None, - blend_params=None, + cameras: Optional[TensorProperties] = None, + lights: Optional[TensorProperties] = None, + materials: Optional[Materials] = None, + blend_params: Optional[BlendParams] = None, ) -> None: super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -189,13 +195,14 @@ class HardGouraudShader(nn.Module): def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - if self.cameras is not None: - self.cameras = self.cameras.to(device) + cameras = self.cameras + if cameras is not None: + self.cameras = cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self - def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: msg = "Cameras must be specified either at initialization \ @@ -239,10 +246,10 @@ class SoftGouraudShader(nn.Module): def __init__( self, device: Device = "cpu", - cameras=None, - lights=None, - materials=None, - blend_params=None, + cameras: Optional[TensorProperties] = None, + lights: Optional[TensorProperties] = None, + materials: Optional[Materials] = None, + blend_params: Optional[BlendParams] = None, ) -> None: super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -254,13 +261,14 @@ class SoftGouraudShader(nn.Module): def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - if self.cameras is not None: - self.cameras = self.cameras.to(device) + cameras = self.cameras + if cameras is not None: + self.cameras = cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self - def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: msg = "Cameras must be specified either at initialization \ @@ -284,7 +292,11 @@ class SoftGouraudShader(nn.Module): def TexturedSoftPhongShader( - device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None + device: Device = "cpu", + cameras: Optional[TensorProperties] = None, + lights: Optional[TensorProperties] = None, + materials: Optional[Materials] = None, + blend_params: Optional[BlendParams] = None, ): """ TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead. @@ -321,10 +333,10 @@ class HardFlatShader(nn.Module): def __init__( self, device: Device = "cpu", - cameras=None, - lights=None, - materials=None, - blend_params=None, + cameras: Optional[TensorProperties] = None, + lights: Optional[TensorProperties] = None, + materials: Optional[Materials] = None, + blend_params: Optional[BlendParams] = None, ) -> None: super().__init__() self.lights = lights if lights is not None else PointLights(device=device) @@ -336,13 +348,14 @@ class HardFlatShader(nn.Module): def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - if self.cameras is not None: - self.cameras = self.cameras.to(device) + cameras = self.cameras + if cameras is not None: + self.cameras = cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self - def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) if cameras is None: msg = "Cameras must be specified either at initialization \ @@ -381,11 +394,11 @@ class SoftSilhouetteShader(nn.Module): 3D Reasoning', ICCV 2019 """ - def __init__(self, blend_params=None) -> None: + def __init__(self, blend_params: Optional[BlendParams] = None) -> None: super().__init__() self.blend_params = blend_params if blend_params is not None else BlendParams() - def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: """ Only want to render the silhouette so RGB values can be ones. There is no need for lighting or texturing diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index b002ddb2..7e73bfa5 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -223,7 +223,7 @@ class TexturesBase: return new_props - def sample_textures(self): + def sample_textures(self) -> torch.Tensor: """ Different texture classes sample textures in different ways e.g. for vertex textures, the values at each vertex @@ -237,7 +237,7 @@ class TexturesBase: """ raise NotImplementedError() - def faces_verts_textures_packed(self): + def faces_verts_textures_packed(self) -> torch.Tensor: """ Returns the texture for each vertex for each face in the mesh. For N meshes, this function returns sum(Fi)x3xC where Fi is the @@ -248,14 +248,14 @@ class TexturesBase: """ raise NotImplementedError() - def clone(self): + def clone(self) -> "TexturesBase": """ Each texture class should implement a method to clone all necessary internal tensors. """ raise NotImplementedError() - def detach(self): + def detach(self) -> "TexturesBase": """ Each texture class should implement a method to detach all necessary internal tensors. @@ -394,7 +394,7 @@ class TexturesAtlas(TexturesBase): # refer to the __init__ of Meshes. self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) - def clone(self): + def clone(self) -> "TexturesAtlas": tex = self.__class__(atlas=self.atlas_padded().clone()) if self._atlas_list is not None: tex._atlas_list = [atlas.clone() for atlas in self._atlas_list] @@ -407,7 +407,7 @@ class TexturesAtlas(TexturesBase): tex._num_faces_per_mesh = num_faces return tex - def detach(self): + def detach(self) -> "TexturesAtlas": tex = self.__class__(atlas=self.atlas_padded().detach()) if self._atlas_list is not None: tex._atlas_list = [atlas.detach() for atlas in self._atlas_list] diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index 99bb4157..cf0d1187 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -160,7 +160,7 @@ class TensorProperties(nn.Module): msg = "Expected index of type int or slice; got %r" raise ValueError(msg % type(index)) - def to(self, device: Device = "cpu"): + def to(self, device: Device = "cpu") -> "TensorProperties": """ In place operation to move class properties which are tensors to a specified device. If self has a property "device", update this as well. @@ -174,7 +174,7 @@ class TensorProperties(nn.Module): setattr(self, k, v.to(device_)) return self - def clone(self, other): + def clone(self, other) -> "TensorProperties": """ Update the tensor properties of other with the cloned properties of self. """ @@ -189,7 +189,7 @@ class TensorProperties(nn.Module): setattr(other, k, v_clone) return other - def gather_props(self, batch_idx): + def gather_props(self, batch_idx) -> "TensorProperties": """ This is an in place operation to reformat all tensor class attributes based on a set of given indices using torch.gather. This is useful when diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index f6bcd2df..2b798820 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -187,15 +187,15 @@ class Volumes: """ # handle densities - densities, grid_sizes = self._convert_densities_features_to_tensor( + densities_, grid_sizes = self._convert_densities_features_to_tensor( densities, "densities" ) # take device from densities - self.device = densities.device + self.device = densities_.device # assign to the internal buffers - self._densities = densities + self._densities = densities_ self._grid_sizes = grid_sizes # handle features @@ -497,7 +497,6 @@ class Volumes: ) def __len__(self) -> int: - # pyre-fixme[16]: `List` has no attribute `shape`. return self._densities.shape[0] def __getitem__( @@ -547,8 +546,6 @@ class Volumes: Returns: **densities**: The tensor of volume densities. """ - # pyre-fixme[7]: Expected `Tensor` but got `Union[List[torch.Tensor], - # torch.Tensor]`. return self._densities def densities_list(self) -> List[torch.Tensor]: @@ -723,7 +720,6 @@ class Volumes: return other other.device = device_ - # pyre-fixme[16]: `List` has no attribute `to`. other._densities = self._densities.to(device_) if self._features is not None: # pyre-fixme[16]: `Optional` has no attribute `to`. diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 03ae3409..396eaa18 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -10,6 +10,8 @@ from typing import Optional import torch import torch.nn.functional as F +from ..common.types import Device + """ The transformation matrices returned from the functions in this file assume @@ -286,7 +288,9 @@ def matrix_to_euler_angles(matrix, convention: str): return torch.stack(o, -1) -def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None): +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +): """ Generate random quaternions representing rotations, i.e. versors with nonnegative real part. @@ -306,7 +310,9 @@ def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None) return o -def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None): +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +): """ Generate random rotations as 3x3 rotation matrices. @@ -323,7 +329,9 @@ def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None): return quaternion_to_matrix(quaternions) -def random_rotation(dtype: Optional[torch.dtype] = None, device=None): +def random_rotation( + dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +): """ Generate a single random 3x3 rotation matrix.