mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
More type annotations
Summary: More type annotations: device, shaders, pluggable I/O, stats in NeRF project, cameras, textures, etc... Reviewed By: nikhilaravi Differential Revision: D29327396 fbshipit-source-id: cdf0ceaaa010e22423088752688c8dd81f1acc3c
This commit is contained in:
parent
542e2e7c07
commit
f593bfd3c2
@ -136,7 +136,7 @@ def download_data(
|
||||
dataset_names: Optional[List[str]] = None,
|
||||
data_root: str = DEFAULT_DATA_ROOT,
|
||||
url_root: str = DEFAULT_URL_ROOT,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Downloads the relevant dataset files.
|
||||
|
||||
|
@ -29,7 +29,7 @@ class AverageMeter:
|
||||
self.history = []
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the running average meter.
|
||||
"""
|
||||
@ -38,7 +38,7 @@ class AverageMeter:
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val: float, n: int = 1, epoch: int = 0):
|
||||
def update(self, val: float, n: int = 1, epoch: int = 0) -> None:
|
||||
"""
|
||||
Updates the average meter with a value `val`.
|
||||
|
||||
@ -123,7 +123,7 @@ class Stats:
|
||||
self.plot_file = plot_file
|
||||
self.hard_reset(epoch=epoch)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Called before an epoch to clear current epoch buffers.
|
||||
"""
|
||||
@ -138,7 +138,7 @@ class Stats:
|
||||
# Set a new timestamp.
|
||||
self._epoch_start = time.time()
|
||||
|
||||
def hard_reset(self, epoch: int = -1):
|
||||
def hard_reset(self, epoch: int = -1) -> None:
|
||||
"""
|
||||
Erases all logged data.
|
||||
"""
|
||||
@ -149,7 +149,7 @@ class Stats:
|
||||
self.stats = {}
|
||||
self.reset()
|
||||
|
||||
def new_epoch(self):
|
||||
def new_epoch(self) -> None:
|
||||
"""
|
||||
Initializes a new epoch.
|
||||
"""
|
||||
@ -166,7 +166,7 @@ class Stats:
|
||||
val = float(val.sum())
|
||||
return val
|
||||
|
||||
def update(self, preds: dict, stat_set: str = "train"):
|
||||
def update(self, preds: dict, stat_set: str = "train") -> None:
|
||||
"""
|
||||
Update the internal logs with metrics of a training step.
|
||||
|
||||
@ -211,7 +211,7 @@ class Stats:
|
||||
if val is not None:
|
||||
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
|
||||
|
||||
def print(self, max_it: Optional[int] = None, stat_set: str = "train"):
|
||||
def print(self, max_it: Optional[int] = None, stat_set: str = "train") -> None:
|
||||
"""
|
||||
Print the current values of all stored stats.
|
||||
|
||||
@ -247,7 +247,7 @@ class Stats:
|
||||
viz: Visdom = None,
|
||||
visdom_env: Optional[str] = None,
|
||||
plot_file: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Plot the line charts of the history of the stats.
|
||||
|
||||
|
@ -42,14 +42,13 @@ from base64 import b64decode
|
||||
from collections import deque
|
||||
from enum import IntEnum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
from pytorch3d.io.utils import _open_file
|
||||
from pytorch3d.io.utils import PathOrStr, _open_file
|
||||
from pytorch3d.renderer.mesh import TexturesBase, TexturesUV, TexturesVertex
|
||||
from pytorch3d.structures import Meshes, join_meshes_as_scene
|
||||
from pytorch3d.transforms import Transform3d, quaternion_to_matrix
|
||||
@ -498,7 +497,7 @@ class _GLTFLoader:
|
||||
|
||||
|
||||
def load_meshes(
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
include_textures: bool = True,
|
||||
) -> List[Tuple[Optional[str], Meshes]]:
|
||||
@ -544,7 +543,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
include_textures: bool,
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
@ -566,7 +565,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
**kwargs,
|
||||
|
@ -18,7 +18,12 @@ from iopath.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
from pytorch3d.common.types import Device
|
||||
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
from pytorch3d.io.utils import (
|
||||
PathOrStr,
|
||||
_check_faces_indices,
|
||||
_make_tensor,
|
||||
_open_file,
|
||||
)
|
||||
from pytorch3d.renderer import TexturesAtlas, TexturesUV
|
||||
from pytorch3d.structures import Meshes, join_meshes_as_batch
|
||||
|
||||
@ -213,7 +218,7 @@ def load_obj(
|
||||
None.
|
||||
"""
|
||||
data_dir = "./"
|
||||
if isinstance(f, (str, bytes, os.PathLike)):
|
||||
if isinstance(f, (str, bytes, Path)):
|
||||
data_dir = os.path.dirname(f)
|
||||
if path_manager is None:
|
||||
path_manager = PathManager()
|
||||
@ -297,7 +302,7 @@ class MeshObjFormat(MeshFormatInterpreter):
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
include_textures: bool,
|
||||
device: Device,
|
||||
path_manager: PathManager,
|
||||
@ -322,7 +327,7 @@ class MeshObjFormat(MeshFormatInterpreter):
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
@ -650,7 +655,7 @@ def _load_obj(
|
||||
|
||||
|
||||
def save_obj(
|
||||
f: Union[str, os.PathLike],
|
||||
f: PathOrStr,
|
||||
verts,
|
||||
faces,
|
||||
decimal_places: Optional[int] = None,
|
||||
|
@ -13,13 +13,12 @@ This format is introduced, for example, at
|
||||
http://www.geomview.org/docs/html/OFF.html .
|
||||
"""
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io.utils import _check_faces_indices, _open_file
|
||||
from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _open_file
|
||||
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
@ -424,7 +423,7 @@ class MeshOffFormat(MeshFormatInterpreter):
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
include_textures: bool,
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
@ -460,7 +459,7 @@ class MeshOffFormat(MeshFormatInterpreter):
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
|
@ -5,10 +5,12 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
import pathlib
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.common.types import Device
|
||||
from pytorch3d.io.utils import PathOrStr
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
|
||||
@ -20,14 +22,14 @@ its load_* and save_* functions.
|
||||
"""
|
||||
|
||||
|
||||
def endswith(path, suffixes: Tuple[str, ...]) -> bool:
|
||||
def endswith(path: PathOrStr, suffixes: Tuple[str, ...]) -> bool:
|
||||
"""
|
||||
Returns whether the path ends with one of the given suffixes.
|
||||
If `path` is not actually a path, returns True. This is useful
|
||||
for allowing interpreters to bypass inappropriate paths, but
|
||||
always accepting streams.
|
||||
"""
|
||||
if isinstance(path, Path):
|
||||
if isinstance(path, pathlib.Path):
|
||||
return path.suffix.lower() in suffixes
|
||||
if isinstance(path, str):
|
||||
return path.lower().endswith(suffixes)
|
||||
@ -42,9 +44,9 @@ class MeshFormatInterpreter:
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
include_textures: bool,
|
||||
device,
|
||||
device: Device,
|
||||
path_manager: PathManager,
|
||||
**kwargs,
|
||||
) -> Optional[Meshes]:
|
||||
@ -68,7 +70,7 @@ class MeshFormatInterpreter:
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
**kwargs,
|
||||
@ -96,7 +98,7 @@ class PointcloudFormatInterpreter:
|
||||
"""
|
||||
|
||||
def read(
|
||||
self, path: Union[str, Path], device, path_manager: PathManager, **kwargs
|
||||
self, path: PathOrStr, device: Device, path_manager: PathManager, **kwargs
|
||||
) -> Optional[Pointclouds]:
|
||||
"""
|
||||
Read the data from the specified file and return it as
|
||||
@ -117,7 +119,7 @@ class PointcloudFormatInterpreter:
|
||||
def save(
|
||||
self,
|
||||
data: Pointclouds,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
**kwargs,
|
||||
|
@ -15,13 +15,17 @@ import sys
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from io import BytesIO, TextIOBase
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
from pytorch3d.io.utils import (
|
||||
PathOrStr,
|
||||
_check_faces_indices,
|
||||
_make_tensor,
|
||||
_open_file,
|
||||
)
|
||||
from pytorch3d.renderer import TexturesVertex
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
@ -1237,7 +1241,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
include_textures: bool,
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
@ -1269,7 +1273,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
def save(
|
||||
self,
|
||||
data: Meshes,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
@ -1318,7 +1322,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
||||
|
||||
def read(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
device,
|
||||
path_manager: PathManager,
|
||||
**kwargs,
|
||||
@ -1339,7 +1343,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
||||
def save(
|
||||
self,
|
||||
data: Pointclouds,
|
||||
path: Union[str, Path],
|
||||
path: PathOrStr,
|
||||
path_manager: PathManager,
|
||||
binary: Optional[bool],
|
||||
decimal_places: Optional[int] = None,
|
||||
|
@ -7,7 +7,7 @@
|
||||
import contextlib
|
||||
import pathlib
|
||||
import warnings
|
||||
from typing import IO, ContextManager, Optional
|
||||
from typing import IO, ContextManager, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -25,6 +25,9 @@ def nullcontext(x):
|
||||
yield x
|
||||
|
||||
|
||||
PathOrStr = Union[pathlib.Path, str]
|
||||
|
||||
|
||||
def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
|
||||
if isinstance(f, str):
|
||||
f = path_manager.open(f, mode)
|
||||
|
@ -20,10 +20,12 @@ from pytorch3d import _C
|
||||
class BlendParams(NamedTuple):
|
||||
sigma: float = 1e-4
|
||||
gamma: float = 1e-4
|
||||
background_color: Sequence = (1.0, 1.0, 1.0)
|
||||
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
|
||||
|
||||
|
||||
def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
def hard_rgb_blend(
|
||||
colors: torch.Tensor, fragments, blend_params: BlendParams
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Naive blending of top K faces to return an RGBA image
|
||||
- **RGB** - choose color of the closest point i.e. K=0
|
||||
@ -47,10 +49,11 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
# Mask for the background.
|
||||
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
||||
|
||||
if torch.is_tensor(blend_params.background_color):
|
||||
background_color = blend_params.background_color.to(device)
|
||||
background_color_ = blend_params.background_color
|
||||
if isinstance(background_color_, torch.Tensor):
|
||||
background_color = background_color_.to(device)
|
||||
else:
|
||||
background_color = colors.new_tensor(blend_params.background_color) # (3)
|
||||
background_color = colors.new_tensor(background_color_) # pyre-fixme[16]
|
||||
|
||||
# Find out how much background_color needs to be expanded to be used for masked_scatter.
|
||||
num_background_pixels = is_background.sum()
|
||||
@ -90,7 +93,7 @@ class _SigmoidAlphaBlend(torch.autograd.Function):
|
||||
_sigmoid_alpha = _SigmoidAlphaBlend.apply
|
||||
|
||||
|
||||
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
def sigmoid_alpha_blend(colors, fragments, blend_params: BlendParams) -> torch.Tensor:
|
||||
"""
|
||||
Silhouette blending to return an RGBA image
|
||||
- **RGB** - choose color of the closest point.
|
||||
@ -121,9 +124,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
|
||||
|
||||
def softmax_rgb_blend(
|
||||
colors,
|
||||
colors: torch.Tensor,
|
||||
fragments,
|
||||
blend_params,
|
||||
blend_params: BlendParams,
|
||||
znear: Union[float, torch.Tensor] = 1.0,
|
||||
zfar: Union[float, torch.Tensor] = 100,
|
||||
) -> torch.Tensor:
|
||||
@ -167,11 +170,11 @@ def softmax_rgb_blend(
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
device = fragments.pix_to_face.device
|
||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
|
||||
background = blend_params.background_color
|
||||
if not torch.is_tensor(background):
|
||||
background = torch.tensor(background, dtype=torch.float32, device=device)
|
||||
background_ = blend_params.background_color
|
||||
if not isinstance(background_, torch.Tensor):
|
||||
background = torch.tensor(background_, dtype=torch.float32, device=device)
|
||||
else:
|
||||
background = background.to(device)
|
||||
background = background_.to(device)
|
||||
|
||||
# Weight for background color
|
||||
eps = 1e-10
|
||||
|
@ -172,9 +172,11 @@ class CamerasBase(TensorProperties):
|
||||
A Transform3d object which represents a batch of transforms
|
||||
of shape (N, 3, 3)
|
||||
"""
|
||||
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
|
||||
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
|
||||
world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
|
||||
R: torch.Tensor = kwargs.get("R", self.R)
|
||||
T: torch.Tensor = kwargs.get("T", self.T)
|
||||
self.R = R # pyre-ignore[16]
|
||||
self.T = T # pyre-ignore[16]
|
||||
world_to_view_transform = get_world_to_view_transform(R=R, T=T)
|
||||
return world_to_view_transform
|
||||
|
||||
def get_full_projection_transform(self, **kwargs) -> Transform3d:
|
||||
@ -195,8 +197,8 @@ class CamerasBase(TensorProperties):
|
||||
a Transform3d object which represents a batch of transforms
|
||||
of shape (N, 3, 3)
|
||||
"""
|
||||
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
|
||||
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
|
||||
self.R: torch.Tensor = kwargs.get("R", self.R) # pyre-ignore[16]
|
||||
self.T: torch.Tensor = kwargs.get("T", self.T) # pyre-ignore[16]
|
||||
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
|
||||
view_to_ndc_transform = self.get_projection_transform(**kwargs)
|
||||
return world_to_view_transform.compose(view_to_ndc_transform)
|
||||
@ -293,10 +295,10 @@ def OpenGLPerspectiveCameras(
|
||||
aspect_ratio=1.0,
|
||||
fov=60.0,
|
||||
degrees: bool = True,
|
||||
R=_R,
|
||||
T=_T,
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
device: Device = "cpu",
|
||||
):
|
||||
) -> "FoVPerspectiveCameras":
|
||||
"""
|
||||
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
|
||||
Preserving OpenGLPerspectiveCameras for backward compatibility.
|
||||
@ -360,9 +362,9 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
aspect_ratio=1.0,
|
||||
fov=60.0,
|
||||
degrees: bool = True,
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
K: Optional[torch.Tensor] = None,
|
||||
device: Device = "cpu",
|
||||
) -> None:
|
||||
"""
|
||||
@ -397,7 +399,7 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
self.degrees = degrees
|
||||
|
||||
def compute_projection_matrix(
|
||||
self, znear, zfar, fov, aspect_ratio, degrees
|
||||
self, znear, zfar, fov, aspect_ratio, degrees: bool
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the calibration matrix K of shape (N, 4, 4)
|
||||
@ -559,10 +561,10 @@ def OpenGLOrthographicCameras(
|
||||
left=-1.0,
|
||||
right=1.0,
|
||||
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
|
||||
R=_R,
|
||||
T=_T,
|
||||
device="cpu",
|
||||
):
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
device: Device = "cpu",
|
||||
) -> "FoVOrthographicCameras":
|
||||
"""
|
||||
OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead.
|
||||
Preserving OpenGLOrthographicCameras for backward compatibility.
|
||||
@ -605,10 +607,10 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
max_x=1.0,
|
||||
min_x=-1.0,
|
||||
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
device="cpu",
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
K: Optional[torch.Tensor] = None,
|
||||
device: Device = "cpu",
|
||||
):
|
||||
"""
|
||||
|
||||
@ -784,8 +786,12 @@ we assume the parameters are in screen space.
|
||||
|
||||
|
||||
def SfMPerspectiveCameras(
|
||||
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
|
||||
):
|
||||
focal_length=1.0,
|
||||
principal_point=((0.0, 0.0),),
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
device: Device = "cpu",
|
||||
) -> "PerspectiveCameras":
|
||||
"""
|
||||
SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
|
||||
Preserving SfMPerspectiveCameras for backward compatibility.
|
||||
@ -843,10 +849,10 @@ class PerspectiveCameras(CamerasBase):
|
||||
self,
|
||||
focal_length=1.0,
|
||||
principal_point=((0.0, 0.0),),
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
device="cpu",
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
K: Optional[torch.Tensor] = None,
|
||||
device: Device = "cpu",
|
||||
image_size=((-1, -1),),
|
||||
) -> None:
|
||||
"""
|
||||
@ -950,8 +956,12 @@ class PerspectiveCameras(CamerasBase):
|
||||
|
||||
|
||||
def SfMOrthographicCameras(
|
||||
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
|
||||
):
|
||||
focal_length=1.0,
|
||||
principal_point=((0.0, 0.0),),
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
device: Device = "cpu",
|
||||
) -> "OrthographicCameras":
|
||||
"""
|
||||
SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
|
||||
Preserving SfMOrthographicCameras for backward compatibility.
|
||||
@ -1008,10 +1018,10 @@ class OrthographicCameras(CamerasBase):
|
||||
self,
|
||||
focal_length=1.0,
|
||||
principal_point=((0.0, 0.0),),
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
device="cpu",
|
||||
R: torch.Tensor = _R,
|
||||
T: torch.Tensor = _T,
|
||||
K: Optional[torch.Tensor] = None,
|
||||
device: Device = "cpu",
|
||||
image_size=((-1, -1),),
|
||||
) -> None:
|
||||
"""
|
||||
@ -1116,8 +1126,8 @@ class OrthographicCameras(CamerasBase):
|
||||
|
||||
|
||||
def _get_sfm_calibration_matrix(
|
||||
N,
|
||||
device,
|
||||
N: int,
|
||||
device: Device,
|
||||
focal_length,
|
||||
principal_point,
|
||||
orthographic: bool = False,
|
||||
@ -1216,7 +1226,9 @@ def _get_sfm_calibration_matrix(
|
||||
################################################
|
||||
|
||||
|
||||
def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
|
||||
def get_world_to_view_transform(
|
||||
R: torch.Tensor = _R, T: torch.Tensor = _T
|
||||
) -> Transform3d:
|
||||
"""
|
||||
This function returns a Transform3d representing the transformation
|
||||
matrix to go from world space to view space by applying a rotation and
|
||||
@ -1250,13 +1262,17 @@ def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
|
||||
raise ValueError(msg % repr(R.shape))
|
||||
|
||||
# Create a Transform3d object
|
||||
T = Translate(T, device=T.device)
|
||||
R = Rotate(R, device=R.device)
|
||||
return R.compose(T)
|
||||
T_ = Translate(T, device=T.device)
|
||||
R_ = Rotate(R, device=R.device)
|
||||
return R_.compose(T_)
|
||||
|
||||
|
||||
def camera_position_from_spherical_angles(
|
||||
distance, elevation, azimuth, degrees: bool = True, device: str = "cpu"
|
||||
distance: float,
|
||||
elevation: float,
|
||||
azimuth: float,
|
||||
degrees: bool = True,
|
||||
device: Device = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the location of the camera based on the distance away from
|
||||
@ -1294,7 +1310,7 @@ def camera_position_from_spherical_angles(
|
||||
|
||||
|
||||
def look_at_rotation(
|
||||
camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: str = "cpu"
|
||||
camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: Device = "cpu"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function takes a vector 'camera_position' which specifies the location
|
||||
@ -1351,7 +1367,7 @@ def look_at_view_transform(
|
||||
eye: Optional[Sequence] = None,
|
||||
at=((0, 0, 0),), # (1, 3)
|
||||
up=((0, 1, 0),), # (1, 3)
|
||||
device="cpu",
|
||||
device: Device = "cpu",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function returns a rotation and translation matrix
|
||||
|
@ -5,11 +5,13 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...common.types import Device
|
||||
from ...structures.meshes import Meshes
|
||||
from ..blending import (
|
||||
BlendParams,
|
||||
hard_rgb_blend,
|
||||
@ -18,6 +20,8 @@ from ..blending import (
|
||||
)
|
||||
from ..lighting import PointLights
|
||||
from ..materials import Materials
|
||||
from ..utils import TensorProperties
|
||||
from .rasterizer import Fragments
|
||||
from .shading import flat_shading, gouraud_shading, phong_shading
|
||||
|
||||
|
||||
@ -47,10 +51,10 @@ class HardPhongShader(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras=None,
|
||||
lights=None,
|
||||
materials=None,
|
||||
blend_params=None,
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
@ -62,13 +66,14 @@ class HardPhongShader(nn.Module):
|
||||
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
if self.cameras is not None:
|
||||
self.cameras = self.cameras.to(device)
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
@ -108,10 +113,10 @@ class SoftPhongShader(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras=None,
|
||||
lights=None,
|
||||
materials=None,
|
||||
blend_params=None,
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
@ -123,13 +128,14 @@ class SoftPhongShader(nn.Module):
|
||||
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
if self.cameras is not None:
|
||||
self.cameras = self.cameras.to(device)
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
@ -174,10 +180,10 @@ class HardGouraudShader(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras=None,
|
||||
lights=None,
|
||||
materials=None,
|
||||
blend_params=None,
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
@ -189,13 +195,14 @@ class HardGouraudShader(nn.Module):
|
||||
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
if self.cameras is not None:
|
||||
self.cameras = self.cameras.to(device)
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
@ -239,10 +246,10 @@ class SoftGouraudShader(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras=None,
|
||||
lights=None,
|
||||
materials=None,
|
||||
blend_params=None,
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
@ -254,13 +261,14 @@ class SoftGouraudShader(nn.Module):
|
||||
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
if self.cameras is not None:
|
||||
self.cameras = self.cameras.to(device)
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
@ -284,7 +292,11 @@ class SoftGouraudShader(nn.Module):
|
||||
|
||||
|
||||
def TexturedSoftPhongShader(
|
||||
device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||
device: Device = "cpu",
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
):
|
||||
"""
|
||||
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
|
||||
@ -321,10 +333,10 @@ class HardFlatShader(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
device: Device = "cpu",
|
||||
cameras=None,
|
||||
lights=None,
|
||||
materials=None,
|
||||
blend_params=None,
|
||||
cameras: Optional[TensorProperties] = None,
|
||||
lights: Optional[TensorProperties] = None,
|
||||
materials: Optional[Materials] = None,
|
||||
blend_params: Optional[BlendParams] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
@ -336,13 +348,14 @@ class HardFlatShader(nn.Module):
|
||||
|
||||
def to(self, device: Device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
if self.cameras is not None:
|
||||
self.cameras = self.cameras.to(device)
|
||||
cameras = self.cameras
|
||||
if cameras is not None:
|
||||
self.cameras = cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
msg = "Cameras must be specified either at initialization \
|
||||
@ -381,11 +394,11 @@ class SoftSilhouetteShader(nn.Module):
|
||||
3D Reasoning', ICCV 2019
|
||||
"""
|
||||
|
||||
def __init__(self, blend_params=None) -> None:
|
||||
def __init__(self, blend_params: Optional[BlendParams] = None) -> None:
|
||||
super().__init__()
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Only want to render the silhouette so RGB values can be ones.
|
||||
There is no need for lighting or texturing
|
||||
|
@ -223,7 +223,7 @@ class TexturesBase:
|
||||
|
||||
return new_props
|
||||
|
||||
def sample_textures(self):
|
||||
def sample_textures(self) -> torch.Tensor:
|
||||
"""
|
||||
Different texture classes sample textures in different ways
|
||||
e.g. for vertex textures, the values at each vertex
|
||||
@ -237,7 +237,7 @@ class TexturesBase:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def faces_verts_textures_packed(self):
|
||||
def faces_verts_textures_packed(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the texture for each vertex for each face in the mesh.
|
||||
For N meshes, this function returns sum(Fi)x3xC where Fi is the
|
||||
@ -248,14 +248,14 @@ class TexturesBase:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def clone(self):
|
||||
def clone(self) -> "TexturesBase":
|
||||
"""
|
||||
Each texture class should implement a method
|
||||
to clone all necessary internal tensors.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def detach(self):
|
||||
def detach(self) -> "TexturesBase":
|
||||
"""
|
||||
Each texture class should implement a method
|
||||
to detach all necessary internal tensors.
|
||||
@ -394,7 +394,7 @@ class TexturesAtlas(TexturesBase):
|
||||
# refer to the __init__ of Meshes.
|
||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||
|
||||
def clone(self):
|
||||
def clone(self) -> "TexturesAtlas":
|
||||
tex = self.__class__(atlas=self.atlas_padded().clone())
|
||||
if self._atlas_list is not None:
|
||||
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
|
||||
@ -407,7 +407,7 @@ class TexturesAtlas(TexturesBase):
|
||||
tex._num_faces_per_mesh = num_faces
|
||||
return tex
|
||||
|
||||
def detach(self):
|
||||
def detach(self) -> "TexturesAtlas":
|
||||
tex = self.__class__(atlas=self.atlas_padded().detach())
|
||||
if self._atlas_list is not None:
|
||||
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
|
||||
|
@ -160,7 +160,7 @@ class TensorProperties(nn.Module):
|
||||
msg = "Expected index of type int or slice; got %r"
|
||||
raise ValueError(msg % type(index))
|
||||
|
||||
def to(self, device: Device = "cpu"):
|
||||
def to(self, device: Device = "cpu") -> "TensorProperties":
|
||||
"""
|
||||
In place operation to move class properties which are tensors to a
|
||||
specified device. If self has a property "device", update this as well.
|
||||
@ -174,7 +174,7 @@ class TensorProperties(nn.Module):
|
||||
setattr(self, k, v.to(device_))
|
||||
return self
|
||||
|
||||
def clone(self, other):
|
||||
def clone(self, other) -> "TensorProperties":
|
||||
"""
|
||||
Update the tensor properties of other with the cloned properties of self.
|
||||
"""
|
||||
@ -189,7 +189,7 @@ class TensorProperties(nn.Module):
|
||||
setattr(other, k, v_clone)
|
||||
return other
|
||||
|
||||
def gather_props(self, batch_idx):
|
||||
def gather_props(self, batch_idx) -> "TensorProperties":
|
||||
"""
|
||||
This is an in place operation to reformat all tensor class attributes
|
||||
based on a set of given indices using torch.gather. This is useful when
|
||||
|
@ -187,15 +187,15 @@ class Volumes:
|
||||
"""
|
||||
|
||||
# handle densities
|
||||
densities, grid_sizes = self._convert_densities_features_to_tensor(
|
||||
densities_, grid_sizes = self._convert_densities_features_to_tensor(
|
||||
densities, "densities"
|
||||
)
|
||||
|
||||
# take device from densities
|
||||
self.device = densities.device
|
||||
self.device = densities_.device
|
||||
|
||||
# assign to the internal buffers
|
||||
self._densities = densities
|
||||
self._densities = densities_
|
||||
self._grid_sizes = grid_sizes
|
||||
|
||||
# handle features
|
||||
@ -497,7 +497,6 @@ class Volumes:
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
# pyre-fixme[16]: `List` has no attribute `shape`.
|
||||
return self._densities.shape[0]
|
||||
|
||||
def __getitem__(
|
||||
@ -547,8 +546,6 @@ class Volumes:
|
||||
Returns:
|
||||
**densities**: The tensor of volume densities.
|
||||
"""
|
||||
# pyre-fixme[7]: Expected `Tensor` but got `Union[List[torch.Tensor],
|
||||
# torch.Tensor]`.
|
||||
return self._densities
|
||||
|
||||
def densities_list(self) -> List[torch.Tensor]:
|
||||
@ -723,7 +720,6 @@ class Volumes:
|
||||
return other
|
||||
|
||||
other.device = device_
|
||||
# pyre-fixme[16]: `List` has no attribute `to`.
|
||||
other._densities = self._densities.to(device_)
|
||||
if self._features is not None:
|
||||
# pyre-fixme[16]: `Optional` has no attribute `to`.
|
||||
|
@ -10,6 +10,8 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..common.types import Device
|
||||
|
||||
|
||||
"""
|
||||
The transformation matrices returned from the functions in this file assume
|
||||
@ -286,7 +288,9 @@ def matrix_to_euler_angles(matrix, convention: str):
|
||||
return torch.stack(o, -1)
|
||||
|
||||
|
||||
def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
||||
def random_quaternions(
|
||||
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||
):
|
||||
"""
|
||||
Generate random quaternions representing rotations,
|
||||
i.e. versors with nonnegative real part.
|
||||
@ -306,7 +310,9 @@ def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None)
|
||||
return o
|
||||
|
||||
|
||||
def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
||||
def random_rotations(
|
||||
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||
):
|
||||
"""
|
||||
Generate random rotations as 3x3 rotation matrices.
|
||||
|
||||
@ -323,7 +329,9 @@ def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
||||
return quaternion_to_matrix(quaternions)
|
||||
|
||||
|
||||
def random_rotation(dtype: Optional[torch.dtype] = None, device=None):
|
||||
def random_rotation(
|
||||
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||
):
|
||||
"""
|
||||
Generate a single random 3x3 rotation matrix.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user