More type annotations

Summary: More type annotations: device, shaders, pluggable I/O, stats in NeRF project, cameras, textures, etc...

Reviewed By: nikhilaravi

Differential Revision: D29327396

fbshipit-source-id: cdf0ceaaa010e22423088752688c8dd81f1acc3c
This commit is contained in:
Patrick Labatut 2021-06-25 19:55:26 -07:00 committed by Facebook GitHub Bot
parent 542e2e7c07
commit f593bfd3c2
15 changed files with 196 additions and 148 deletions

View File

@ -136,7 +136,7 @@ def download_data(
dataset_names: Optional[List[str]] = None,
data_root: str = DEFAULT_DATA_ROOT,
url_root: str = DEFAULT_URL_ROOT,
):
) -> None:
"""
Downloads the relevant dataset files.

View File

@ -29,7 +29,7 @@ class AverageMeter:
self.history = []
self.reset()
def reset(self):
def reset(self) -> None:
"""
Reset the running average meter.
"""
@ -38,7 +38,7 @@ class AverageMeter:
self.sum = 0
self.count = 0
def update(self, val: float, n: int = 1, epoch: int = 0):
def update(self, val: float, n: int = 1, epoch: int = 0) -> None:
"""
Updates the average meter with a value `val`.
@ -123,7 +123,7 @@ class Stats:
self.plot_file = plot_file
self.hard_reset(epoch=epoch)
def reset(self):
def reset(self) -> None:
"""
Called before an epoch to clear current epoch buffers.
"""
@ -138,7 +138,7 @@ class Stats:
# Set a new timestamp.
self._epoch_start = time.time()
def hard_reset(self, epoch: int = -1):
def hard_reset(self, epoch: int = -1) -> None:
"""
Erases all logged data.
"""
@ -149,7 +149,7 @@ class Stats:
self.stats = {}
self.reset()
def new_epoch(self):
def new_epoch(self) -> None:
"""
Initializes a new epoch.
"""
@ -166,7 +166,7 @@ class Stats:
val = float(val.sum())
return val
def update(self, preds: dict, stat_set: str = "train"):
def update(self, preds: dict, stat_set: str = "train") -> None:
"""
Update the internal logs with metrics of a training step.
@ -211,7 +211,7 @@ class Stats:
if val is not None:
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
def print(self, max_it: Optional[int] = None, stat_set: str = "train"):
def print(self, max_it: Optional[int] = None, stat_set: str = "train") -> None:
"""
Print the current values of all stored stats.
@ -247,7 +247,7 @@ class Stats:
viz: Visdom = None,
visdom_env: Optional[str] = None,
plot_file: Optional[str] = None,
):
) -> None:
"""
Plot the line charts of the history of the stats.

View File

@ -42,14 +42,13 @@ from base64 import b64decode
from collections import deque
from enum import IntEnum
from io import BytesIO
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, cast
import numpy as np
import torch
from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.io.utils import _open_file
from pytorch3d.io.utils import PathOrStr, _open_file
from pytorch3d.renderer.mesh import TexturesBase, TexturesUV, TexturesVertex
from pytorch3d.structures import Meshes, join_meshes_as_scene
from pytorch3d.transforms import Transform3d, quaternion_to_matrix
@ -498,7 +497,7 @@ class _GLTFLoader:
def load_meshes(
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
include_textures: bool = True,
) -> List[Tuple[Optional[str], Meshes]]:
@ -544,7 +543,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
def read(
self,
path: Union[str, Path],
path: PathOrStr,
include_textures: bool,
device,
path_manager: PathManager,
@ -566,7 +565,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
def save(
self,
data: Meshes,
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
binary: Optional[bool],
**kwargs,

View File

@ -18,7 +18,12 @@ from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.common.types import Device
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.io.utils import (
PathOrStr,
_check_faces_indices,
_make_tensor,
_open_file,
)
from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch
@ -213,7 +218,7 @@ def load_obj(
None.
"""
data_dir = "./"
if isinstance(f, (str, bytes, os.PathLike)):
if isinstance(f, (str, bytes, Path)):
data_dir = os.path.dirname(f)
if path_manager is None:
path_manager = PathManager()
@ -297,7 +302,7 @@ class MeshObjFormat(MeshFormatInterpreter):
def read(
self,
path: Union[str, Path],
path: PathOrStr,
include_textures: bool,
device: Device,
path_manager: PathManager,
@ -322,7 +327,7 @@ class MeshObjFormat(MeshFormatInterpreter):
def save(
self,
data: Meshes,
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
@ -650,7 +655,7 @@ def _load_obj(
def save_obj(
f: Union[str, os.PathLike],
f: PathOrStr,
verts,
faces,
decimal_places: Optional[int] = None,

View File

@ -13,13 +13,12 @@ This format is introduced, for example, at
http://www.geomview.org/docs/html/OFF.html .
"""
import warnings
from pathlib import Path
from typing import Optional, Tuple, Union, cast
import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _open_file
from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _open_file
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
from pytorch3d.structures import Meshes
@ -424,7 +423,7 @@ class MeshOffFormat(MeshFormatInterpreter):
def read(
self,
path: Union[str, Path],
path: PathOrStr,
include_textures: bool,
device,
path_manager: PathManager,
@ -460,7 +459,7 @@ class MeshOffFormat(MeshFormatInterpreter):
def save(
self,
data: Meshes,
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,

View File

@ -5,10 +5,12 @@
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import pathlib
from typing import Optional, Tuple, Union
from iopath.common.file_io import PathManager
from pytorch3d.common.types import Device
from pytorch3d.io.utils import PathOrStr
from pytorch3d.structures import Meshes, Pointclouds
@ -20,14 +22,14 @@ its load_* and save_* functions.
"""
def endswith(path, suffixes: Tuple[str, ...]) -> bool:
def endswith(path: PathOrStr, suffixes: Tuple[str, ...]) -> bool:
"""
Returns whether the path ends with one of the given suffixes.
If `path` is not actually a path, returns True. This is useful
for allowing interpreters to bypass inappropriate paths, but
always accepting streams.
"""
if isinstance(path, Path):
if isinstance(path, pathlib.Path):
return path.suffix.lower() in suffixes
if isinstance(path, str):
return path.lower().endswith(suffixes)
@ -42,9 +44,9 @@ class MeshFormatInterpreter:
def read(
self,
path: Union[str, Path],
path: PathOrStr,
include_textures: bool,
device,
device: Device,
path_manager: PathManager,
**kwargs,
) -> Optional[Meshes]:
@ -68,7 +70,7 @@ class MeshFormatInterpreter:
def save(
self,
data: Meshes,
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
binary: Optional[bool],
**kwargs,
@ -96,7 +98,7 @@ class PointcloudFormatInterpreter:
"""
def read(
self, path: Union[str, Path], device, path_manager: PathManager, **kwargs
self, path: PathOrStr, device: Device, path_manager: PathManager, **kwargs
) -> Optional[Pointclouds]:
"""
Read the data from the specified file and return it as
@ -117,7 +119,7 @@ class PointcloudFormatInterpreter:
def save(
self,
data: Pointclouds,
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
binary: Optional[bool],
**kwargs,

View File

@ -15,13 +15,17 @@ import sys
import warnings
from collections import namedtuple
from io import BytesIO, TextIOBase
from pathlib import Path
from typing import List, Optional, Tuple, Union, cast
import numpy as np
import torch
from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.io.utils import (
PathOrStr,
_check_faces_indices,
_make_tensor,
_open_file,
)
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds
@ -1237,7 +1241,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
def read(
self,
path: Union[str, Path],
path: PathOrStr,
include_textures: bool,
device,
path_manager: PathManager,
@ -1269,7 +1273,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
def save(
self,
data: Meshes,
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
@ -1318,7 +1322,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
def read(
self,
path: Union[str, Path],
path: PathOrStr,
device,
path_manager: PathManager,
**kwargs,
@ -1339,7 +1343,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
def save(
self,
data: Pointclouds,
path: Union[str, Path],
path: PathOrStr,
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,

View File

@ -7,7 +7,7 @@
import contextlib
import pathlib
import warnings
from typing import IO, ContextManager, Optional
from typing import IO, ContextManager, Optional, Union
import numpy as np
import torch
@ -25,6 +25,9 @@ def nullcontext(x):
yield x
PathOrStr = Union[pathlib.Path, str]
def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
if isinstance(f, str):
f = path_manager.open(f, mode)

View File

@ -20,10 +20,12 @@ from pytorch3d import _C
class BlendParams(NamedTuple):
sigma: float = 1e-4
gamma: float = 1e-4
background_color: Sequence = (1.0, 1.0, 1.0)
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
def hard_rgb_blend(
colors: torch.Tensor, fragments, blend_params: BlendParams
) -> torch.Tensor:
"""
Naive blending of top K faces to return an RGBA image
- **RGB** - choose color of the closest point i.e. K=0
@ -47,10 +49,11 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
# Mask for the background.
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
if torch.is_tensor(blend_params.background_color):
background_color = blend_params.background_color.to(device)
background_color_ = blend_params.background_color
if isinstance(background_color_, torch.Tensor):
background_color = background_color_.to(device)
else:
background_color = colors.new_tensor(blend_params.background_color) # (3)
background_color = colors.new_tensor(background_color_) # pyre-fixme[16]
# Find out how much background_color needs to be expanded to be used for masked_scatter.
num_background_pixels = is_background.sum()
@ -90,7 +93,7 @@ class _SigmoidAlphaBlend(torch.autograd.Function):
_sigmoid_alpha = _SigmoidAlphaBlend.apply
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
def sigmoid_alpha_blend(colors, fragments, blend_params: BlendParams) -> torch.Tensor:
"""
Silhouette blending to return an RGBA image
- **RGB** - choose color of the closest point.
@ -121,9 +124,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
def softmax_rgb_blend(
colors,
colors: torch.Tensor,
fragments,
blend_params,
blend_params: BlendParams,
znear: Union[float, torch.Tensor] = 1.0,
zfar: Union[float, torch.Tensor] = 100,
) -> torch.Tensor:
@ -167,11 +170,11 @@ def softmax_rgb_blend(
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
background = blend_params.background_color
if not torch.is_tensor(background):
background = torch.tensor(background, dtype=torch.float32, device=device)
background_ = blend_params.background_color
if not isinstance(background_, torch.Tensor):
background = torch.tensor(background_, dtype=torch.float32, device=device)
else:
background = background.to(device)
background = background_.to(device)
# Weight for background color
eps = 1e-10

View File

@ -172,9 +172,11 @@ class CamerasBase(TensorProperties):
A Transform3d object which represents a batch of transforms
of shape (N, 3, 3)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
R: torch.Tensor = kwargs.get("R", self.R)
T: torch.Tensor = kwargs.get("T", self.T)
self.R = R # pyre-ignore[16]
self.T = T # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform(R=R, T=T)
return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d:
@ -195,8 +197,8 @@ class CamerasBase(TensorProperties):
a Transform3d object which represents a batch of transforms
of shape (N, 3, 3)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
self.R: torch.Tensor = kwargs.get("R", self.R) # pyre-ignore[16]
self.T: torch.Tensor = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
view_to_ndc_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_ndc_transform)
@ -293,10 +295,10 @@ def OpenGLPerspectiveCameras(
aspect_ratio=1.0,
fov=60.0,
degrees: bool = True,
R=_R,
T=_T,
R: torch.Tensor = _R,
T: torch.Tensor = _T,
device: Device = "cpu",
):
) -> "FoVPerspectiveCameras":
"""
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
Preserving OpenGLPerspectiveCameras for backward compatibility.
@ -360,9 +362,9 @@ class FoVPerspectiveCameras(CamerasBase):
aspect_ratio=1.0,
fov=60.0,
degrees: bool = True,
R=_R,
T=_T,
K=None,
R: torch.Tensor = _R,
T: torch.Tensor = _T,
K: Optional[torch.Tensor] = None,
device: Device = "cpu",
) -> None:
"""
@ -397,7 +399,7 @@ class FoVPerspectiveCameras(CamerasBase):
self.degrees = degrees
def compute_projection_matrix(
self, znear, zfar, fov, aspect_ratio, degrees
self, znear, zfar, fov, aspect_ratio, degrees: bool
) -> torch.Tensor:
"""
Compute the calibration matrix K of shape (N, 4, 4)
@ -559,10 +561,10 @@ def OpenGLOrthographicCameras(
left=-1.0,
right=1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=_R,
T=_T,
device="cpu",
):
R: torch.Tensor = _R,
T: torch.Tensor = _T,
device: Device = "cpu",
) -> "FoVOrthographicCameras":
"""
OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead.
Preserving OpenGLOrthographicCameras for backward compatibility.
@ -605,10 +607,10 @@ class FoVOrthographicCameras(CamerasBase):
max_x=1.0,
min_x=-1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=_R,
T=_T,
K=None,
device="cpu",
R: torch.Tensor = _R,
T: torch.Tensor = _T,
K: Optional[torch.Tensor] = None,
device: Device = "cpu",
):
"""
@ -784,8 +786,12 @@ we assume the parameters are in screen space.
def SfMPerspectiveCameras(
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
):
focal_length=1.0,
principal_point=((0.0, 0.0),),
R: torch.Tensor = _R,
T: torch.Tensor = _T,
device: Device = "cpu",
) -> "PerspectiveCameras":
"""
SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
Preserving SfMPerspectiveCameras for backward compatibility.
@ -843,10 +849,10 @@ class PerspectiveCameras(CamerasBase):
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=_R,
T=_T,
K=None,
device="cpu",
R: torch.Tensor = _R,
T: torch.Tensor = _T,
K: Optional[torch.Tensor] = None,
device: Device = "cpu",
image_size=((-1, -1),),
) -> None:
"""
@ -950,8 +956,12 @@ class PerspectiveCameras(CamerasBase):
def SfMOrthographicCameras(
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
):
focal_length=1.0,
principal_point=((0.0, 0.0),),
R: torch.Tensor = _R,
T: torch.Tensor = _T,
device: Device = "cpu",
) -> "OrthographicCameras":
"""
SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
Preserving SfMOrthographicCameras for backward compatibility.
@ -1008,10 +1018,10 @@ class OrthographicCameras(CamerasBase):
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=_R,
T=_T,
K=None,
device="cpu",
R: torch.Tensor = _R,
T: torch.Tensor = _T,
K: Optional[torch.Tensor] = None,
device: Device = "cpu",
image_size=((-1, -1),),
) -> None:
"""
@ -1116,8 +1126,8 @@ class OrthographicCameras(CamerasBase):
def _get_sfm_calibration_matrix(
N,
device,
N: int,
device: Device,
focal_length,
principal_point,
orthographic: bool = False,
@ -1216,7 +1226,9 @@ def _get_sfm_calibration_matrix(
################################################
def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
def get_world_to_view_transform(
R: torch.Tensor = _R, T: torch.Tensor = _T
) -> Transform3d:
"""
This function returns a Transform3d representing the transformation
matrix to go from world space to view space by applying a rotation and
@ -1250,13 +1262,17 @@ def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
raise ValueError(msg % repr(R.shape))
# Create a Transform3d object
T = Translate(T, device=T.device)
R = Rotate(R, device=R.device)
return R.compose(T)
T_ = Translate(T, device=T.device)
R_ = Rotate(R, device=R.device)
return R_.compose(T_)
def camera_position_from_spherical_angles(
distance, elevation, azimuth, degrees: bool = True, device: str = "cpu"
distance: float,
elevation: float,
azimuth: float,
degrees: bool = True,
device: Device = "cpu",
) -> torch.Tensor:
"""
Calculate the location of the camera based on the distance away from
@ -1294,7 +1310,7 @@ def camera_position_from_spherical_angles(
def look_at_rotation(
camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: str = "cpu"
camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: Device = "cpu"
) -> torch.Tensor:
"""
This function takes a vector 'camera_position' which specifies the location
@ -1351,7 +1367,7 @@ def look_at_view_transform(
eye: Optional[Sequence] = None,
at=((0, 0, 0),), # (1, 3)
up=((0, 1, 0),), # (1, 3)
device="cpu",
device: Device = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function returns a rotation and translation matrix

View File

@ -5,11 +5,13 @@
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Optional
import torch
import torch.nn as nn
from ...common.types import Device
from ...structures.meshes import Meshes
from ..blending import (
BlendParams,
hard_rgb_blend,
@ -18,6 +20,8 @@ from ..blending import (
)
from ..lighting import PointLights
from ..materials import Materials
from ..utils import TensorProperties
from .rasterizer import Fragments
from .shading import flat_shading, gouraud_shading, phong_shading
@ -47,10 +51,10 @@ class HardPhongShader(nn.Module):
def __init__(
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -62,13 +66,14 @@ class HardPhongShader(nn.Module):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
if self.cameras is not None:
self.cameras = self.cameras.to(device)
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
@ -108,10 +113,10 @@ class SoftPhongShader(nn.Module):
def __init__(
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -123,13 +128,14 @@ class SoftPhongShader(nn.Module):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
if self.cameras is not None:
self.cameras = self.cameras.to(device)
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
@ -174,10 +180,10 @@ class HardGouraudShader(nn.Module):
def __init__(
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -189,13 +195,14 @@ class HardGouraudShader(nn.Module):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
if self.cameras is not None:
self.cameras = self.cameras.to(device)
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
@ -239,10 +246,10 @@ class SoftGouraudShader(nn.Module):
def __init__(
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -254,13 +261,14 @@ class SoftGouraudShader(nn.Module):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
if self.cameras is not None:
self.cameras = self.cameras.to(device)
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
@ -284,7 +292,11 @@ class SoftGouraudShader(nn.Module):
def TexturedSoftPhongShader(
device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
):
"""
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
@ -321,10 +333,10 @@ class HardFlatShader(nn.Module):
def __init__(
self,
device: Device = "cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
@ -336,13 +348,14 @@ class HardFlatShader(nn.Module):
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
if self.cameras is not None:
self.cameras = self.cameras.to(device)
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
@ -381,11 +394,11 @@ class SoftSilhouetteShader(nn.Module):
3D Reasoning', ICCV 2019
"""
def __init__(self, blend_params=None) -> None:
def __init__(self, blend_params: Optional[BlendParams] = None) -> None:
super().__init__()
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
"""
Only want to render the silhouette so RGB values can be ones.
There is no need for lighting or texturing

View File

@ -223,7 +223,7 @@ class TexturesBase:
return new_props
def sample_textures(self):
def sample_textures(self) -> torch.Tensor:
"""
Different texture classes sample textures in different ways
e.g. for vertex textures, the values at each vertex
@ -237,7 +237,7 @@ class TexturesBase:
"""
raise NotImplementedError()
def faces_verts_textures_packed(self):
def faces_verts_textures_packed(self) -> torch.Tensor:
"""
Returns the texture for each vertex for each face in the mesh.
For N meshes, this function returns sum(Fi)x3xC where Fi is the
@ -248,14 +248,14 @@ class TexturesBase:
"""
raise NotImplementedError()
def clone(self):
def clone(self) -> "TexturesBase":
"""
Each texture class should implement a method
to clone all necessary internal tensors.
"""
raise NotImplementedError()
def detach(self):
def detach(self) -> "TexturesBase":
"""
Each texture class should implement a method
to detach all necessary internal tensors.
@ -394,7 +394,7 @@ class TexturesAtlas(TexturesBase):
# refer to the __init__ of Meshes.
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def clone(self):
def clone(self) -> "TexturesAtlas":
tex = self.__class__(atlas=self.atlas_padded().clone())
if self._atlas_list is not None:
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
@ -407,7 +407,7 @@ class TexturesAtlas(TexturesBase):
tex._num_faces_per_mesh = num_faces
return tex
def detach(self):
def detach(self) -> "TexturesAtlas":
tex = self.__class__(atlas=self.atlas_padded().detach())
if self._atlas_list is not None:
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]

View File

@ -160,7 +160,7 @@ class TensorProperties(nn.Module):
msg = "Expected index of type int or slice; got %r"
raise ValueError(msg % type(index))
def to(self, device: Device = "cpu"):
def to(self, device: Device = "cpu") -> "TensorProperties":
"""
In place operation to move class properties which are tensors to a
specified device. If self has a property "device", update this as well.
@ -174,7 +174,7 @@ class TensorProperties(nn.Module):
setattr(self, k, v.to(device_))
return self
def clone(self, other):
def clone(self, other) -> "TensorProperties":
"""
Update the tensor properties of other with the cloned properties of self.
"""
@ -189,7 +189,7 @@ class TensorProperties(nn.Module):
setattr(other, k, v_clone)
return other
def gather_props(self, batch_idx):
def gather_props(self, batch_idx) -> "TensorProperties":
"""
This is an in place operation to reformat all tensor class attributes
based on a set of given indices using torch.gather. This is useful when

View File

@ -187,15 +187,15 @@ class Volumes:
"""
# handle densities
densities, grid_sizes = self._convert_densities_features_to_tensor(
densities_, grid_sizes = self._convert_densities_features_to_tensor(
densities, "densities"
)
# take device from densities
self.device = densities.device
self.device = densities_.device
# assign to the internal buffers
self._densities = densities
self._densities = densities_
self._grid_sizes = grid_sizes
# handle features
@ -497,7 +497,6 @@ class Volumes:
)
def __len__(self) -> int:
# pyre-fixme[16]: `List` has no attribute `shape`.
return self._densities.shape[0]
def __getitem__(
@ -547,8 +546,6 @@ class Volumes:
Returns:
**densities**: The tensor of volume densities.
"""
# pyre-fixme[7]: Expected `Tensor` but got `Union[List[torch.Tensor],
# torch.Tensor]`.
return self._densities
def densities_list(self) -> List[torch.Tensor]:
@ -723,7 +720,6 @@ class Volumes:
return other
other.device = device_
# pyre-fixme[16]: `List` has no attribute `to`.
other._densities = self._densities.to(device_)
if self._features is not None:
# pyre-fixme[16]: `Optional` has no attribute `to`.

View File

@ -10,6 +10,8 @@ from typing import Optional
import torch
import torch.nn.functional as F
from ..common.types import Device
"""
The transformation matrices returned from the functions in this file assume
@ -286,7 +288,9 @@ def matrix_to_euler_angles(matrix, convention: str):
return torch.stack(o, -1)
def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None):
def random_quaternions(
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
):
"""
Generate random quaternions representing rotations,
i.e. versors with nonnegative real part.
@ -306,7 +310,9 @@ def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None)
return o
def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
def random_rotations(
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
):
"""
Generate random rotations as 3x3 rotation matrices.
@ -323,7 +329,9 @@ def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
return quaternion_to_matrix(quaternions)
def random_rotation(dtype: Optional[torch.dtype] = None, device=None):
def random_rotation(
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
):
"""
Generate a single random 3x3 rotation matrix.