mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
More type annotations
Summary: More type annotations: device, shaders, pluggable I/O, stats in NeRF project, cameras, textures, etc... Reviewed By: nikhilaravi Differential Revision: D29327396 fbshipit-source-id: cdf0ceaaa010e22423088752688c8dd81f1acc3c
This commit is contained in:
parent
542e2e7c07
commit
f593bfd3c2
@ -136,7 +136,7 @@ def download_data(
|
|||||||
dataset_names: Optional[List[str]] = None,
|
dataset_names: Optional[List[str]] = None,
|
||||||
data_root: str = DEFAULT_DATA_ROOT,
|
data_root: str = DEFAULT_DATA_ROOT,
|
||||||
url_root: str = DEFAULT_URL_ROOT,
|
url_root: str = DEFAULT_URL_ROOT,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Downloads the relevant dataset files.
|
Downloads the relevant dataset files.
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class AverageMeter:
|
|||||||
self.history = []
|
self.history = []
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
"""
|
"""
|
||||||
Reset the running average meter.
|
Reset the running average meter.
|
||||||
"""
|
"""
|
||||||
@ -38,7 +38,7 @@ class AverageMeter:
|
|||||||
self.sum = 0
|
self.sum = 0
|
||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
def update(self, val: float, n: int = 1, epoch: int = 0):
|
def update(self, val: float, n: int = 1, epoch: int = 0) -> None:
|
||||||
"""
|
"""
|
||||||
Updates the average meter with a value `val`.
|
Updates the average meter with a value `val`.
|
||||||
|
|
||||||
@ -123,7 +123,7 @@ class Stats:
|
|||||||
self.plot_file = plot_file
|
self.plot_file = plot_file
|
||||||
self.hard_reset(epoch=epoch)
|
self.hard_reset(epoch=epoch)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
"""
|
"""
|
||||||
Called before an epoch to clear current epoch buffers.
|
Called before an epoch to clear current epoch buffers.
|
||||||
"""
|
"""
|
||||||
@ -138,7 +138,7 @@ class Stats:
|
|||||||
# Set a new timestamp.
|
# Set a new timestamp.
|
||||||
self._epoch_start = time.time()
|
self._epoch_start = time.time()
|
||||||
|
|
||||||
def hard_reset(self, epoch: int = -1):
|
def hard_reset(self, epoch: int = -1) -> None:
|
||||||
"""
|
"""
|
||||||
Erases all logged data.
|
Erases all logged data.
|
||||||
"""
|
"""
|
||||||
@ -149,7 +149,7 @@ class Stats:
|
|||||||
self.stats = {}
|
self.stats = {}
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def new_epoch(self):
|
def new_epoch(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes a new epoch.
|
Initializes a new epoch.
|
||||||
"""
|
"""
|
||||||
@ -166,7 +166,7 @@ class Stats:
|
|||||||
val = float(val.sum())
|
val = float(val.sum())
|
||||||
return val
|
return val
|
||||||
|
|
||||||
def update(self, preds: dict, stat_set: str = "train"):
|
def update(self, preds: dict, stat_set: str = "train") -> None:
|
||||||
"""
|
"""
|
||||||
Update the internal logs with metrics of a training step.
|
Update the internal logs with metrics of a training step.
|
||||||
|
|
||||||
@ -211,7 +211,7 @@ class Stats:
|
|||||||
if val is not None:
|
if val is not None:
|
||||||
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
|
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
|
||||||
|
|
||||||
def print(self, max_it: Optional[int] = None, stat_set: str = "train"):
|
def print(self, max_it: Optional[int] = None, stat_set: str = "train") -> None:
|
||||||
"""
|
"""
|
||||||
Print the current values of all stored stats.
|
Print the current values of all stored stats.
|
||||||
|
|
||||||
@ -247,7 +247,7 @@ class Stats:
|
|||||||
viz: Visdom = None,
|
viz: Visdom = None,
|
||||||
visdom_env: Optional[str] = None,
|
visdom_env: Optional[str] = None,
|
||||||
plot_file: Optional[str] = None,
|
plot_file: Optional[str] = None,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Plot the line charts of the history of the stats.
|
Plot the line charts of the history of the stats.
|
||||||
|
|
||||||
|
@ -42,14 +42,13 @@ from base64 import b64decode
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, cast
|
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytorch3d.io.utils import _open_file
|
from pytorch3d.io.utils import PathOrStr, _open_file
|
||||||
from pytorch3d.renderer.mesh import TexturesBase, TexturesUV, TexturesVertex
|
from pytorch3d.renderer.mesh import TexturesBase, TexturesUV, TexturesVertex
|
||||||
from pytorch3d.structures import Meshes, join_meshes_as_scene
|
from pytorch3d.structures import Meshes, join_meshes_as_scene
|
||||||
from pytorch3d.transforms import Transform3d, quaternion_to_matrix
|
from pytorch3d.transforms import Transform3d, quaternion_to_matrix
|
||||||
@ -498,7 +497,7 @@ class _GLTFLoader:
|
|||||||
|
|
||||||
|
|
||||||
def load_meshes(
|
def load_meshes(
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
include_textures: bool = True,
|
include_textures: bool = True,
|
||||||
) -> List[Tuple[Optional[str], Meshes]]:
|
) -> List[Tuple[Optional[str], Meshes]]:
|
||||||
@ -544,7 +543,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
|
|||||||
|
|
||||||
def read(
|
def read(
|
||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
include_textures: bool,
|
include_textures: bool,
|
||||||
device,
|
device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
@ -566,7 +565,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
data: Meshes,
|
data: Meshes,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -18,7 +18,12 @@ from iopath.common.file_io import PathManager
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytorch3d.common.types import Device
|
from pytorch3d.common.types import Device
|
||||||
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
|
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
|
||||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
from pytorch3d.io.utils import (
|
||||||
|
PathOrStr,
|
||||||
|
_check_faces_indices,
|
||||||
|
_make_tensor,
|
||||||
|
_open_file,
|
||||||
|
)
|
||||||
from pytorch3d.renderer import TexturesAtlas, TexturesUV
|
from pytorch3d.renderer import TexturesAtlas, TexturesUV
|
||||||
from pytorch3d.structures import Meshes, join_meshes_as_batch
|
from pytorch3d.structures import Meshes, join_meshes_as_batch
|
||||||
|
|
||||||
@ -213,7 +218,7 @@ def load_obj(
|
|||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
data_dir = "./"
|
data_dir = "./"
|
||||||
if isinstance(f, (str, bytes, os.PathLike)):
|
if isinstance(f, (str, bytes, Path)):
|
||||||
data_dir = os.path.dirname(f)
|
data_dir = os.path.dirname(f)
|
||||||
if path_manager is None:
|
if path_manager is None:
|
||||||
path_manager = PathManager()
|
path_manager = PathManager()
|
||||||
@ -297,7 +302,7 @@ class MeshObjFormat(MeshFormatInterpreter):
|
|||||||
|
|
||||||
def read(
|
def read(
|
||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
include_textures: bool,
|
include_textures: bool,
|
||||||
device: Device,
|
device: Device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
@ -322,7 +327,7 @@ class MeshObjFormat(MeshFormatInterpreter):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
data: Meshes,
|
data: Meshes,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
@ -650,7 +655,7 @@ def _load_obj(
|
|||||||
|
|
||||||
|
|
||||||
def save_obj(
|
def save_obj(
|
||||||
f: Union[str, os.PathLike],
|
f: PathOrStr,
|
||||||
verts,
|
verts,
|
||||||
faces,
|
faces,
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
|
@ -13,13 +13,12 @@ This format is introduced, for example, at
|
|||||||
http://www.geomview.org/docs/html/OFF.html .
|
http://www.geomview.org/docs/html/OFF.html .
|
||||||
"""
|
"""
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Tuple, Union, cast
|
from typing import Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from pytorch3d.io.utils import _check_faces_indices, _open_file
|
from pytorch3d.io.utils import PathOrStr, _check_faces_indices, _open_file
|
||||||
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
|
from pytorch3d.renderer import TexturesAtlas, TexturesVertex
|
||||||
from pytorch3d.structures import Meshes
|
from pytorch3d.structures import Meshes
|
||||||
|
|
||||||
@ -424,7 +423,7 @@ class MeshOffFormat(MeshFormatInterpreter):
|
|||||||
|
|
||||||
def read(
|
def read(
|
||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
include_textures: bool,
|
include_textures: bool,
|
||||||
device,
|
device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
@ -460,7 +459,7 @@ class MeshOffFormat(MeshFormatInterpreter):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
data: Meshes,
|
data: Meshes,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
|
@ -5,10 +5,12 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from pathlib import Path
|
import pathlib
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
|
from pytorch3d.common.types import Device
|
||||||
|
from pytorch3d.io.utils import PathOrStr
|
||||||
from pytorch3d.structures import Meshes, Pointclouds
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
|
|
||||||
|
|
||||||
@ -20,14 +22,14 @@ its load_* and save_* functions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def endswith(path, suffixes: Tuple[str, ...]) -> bool:
|
def endswith(path: PathOrStr, suffixes: Tuple[str, ...]) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns whether the path ends with one of the given suffixes.
|
Returns whether the path ends with one of the given suffixes.
|
||||||
If `path` is not actually a path, returns True. This is useful
|
If `path` is not actually a path, returns True. This is useful
|
||||||
for allowing interpreters to bypass inappropriate paths, but
|
for allowing interpreters to bypass inappropriate paths, but
|
||||||
always accepting streams.
|
always accepting streams.
|
||||||
"""
|
"""
|
||||||
if isinstance(path, Path):
|
if isinstance(path, pathlib.Path):
|
||||||
return path.suffix.lower() in suffixes
|
return path.suffix.lower() in suffixes
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
return path.lower().endswith(suffixes)
|
return path.lower().endswith(suffixes)
|
||||||
@ -42,9 +44,9 @@ class MeshFormatInterpreter:
|
|||||||
|
|
||||||
def read(
|
def read(
|
||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
include_textures: bool,
|
include_textures: bool,
|
||||||
device,
|
device: Device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Optional[Meshes]:
|
) -> Optional[Meshes]:
|
||||||
@ -68,7 +70,7 @@ class MeshFormatInterpreter:
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
data: Meshes,
|
data: Meshes,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -96,7 +98,7 @@ class PointcloudFormatInterpreter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def read(
|
def read(
|
||||||
self, path: Union[str, Path], device, path_manager: PathManager, **kwargs
|
self, path: PathOrStr, device: Device, path_manager: PathManager, **kwargs
|
||||||
) -> Optional[Pointclouds]:
|
) -> Optional[Pointclouds]:
|
||||||
"""
|
"""
|
||||||
Read the data from the specified file and return it as
|
Read the data from the specified file and return it as
|
||||||
@ -117,7 +119,7 @@ class PointcloudFormatInterpreter:
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
data: Pointclouds,
|
data: Pointclouds,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -15,13 +15,17 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from io import BytesIO, TextIOBase
|
from io import BytesIO, TextIOBase
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Tuple, Union, cast
|
from typing import List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
from pytorch3d.io.utils import (
|
||||||
|
PathOrStr,
|
||||||
|
_check_faces_indices,
|
||||||
|
_make_tensor,
|
||||||
|
_open_file,
|
||||||
|
)
|
||||||
from pytorch3d.renderer import TexturesVertex
|
from pytorch3d.renderer import TexturesVertex
|
||||||
from pytorch3d.structures import Meshes, Pointclouds
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
|
|
||||||
@ -1237,7 +1241,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
|||||||
|
|
||||||
def read(
|
def read(
|
||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
include_textures: bool,
|
include_textures: bool,
|
||||||
device,
|
device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
@ -1269,7 +1273,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
data: Meshes,
|
data: Meshes,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
@ -1318,7 +1322,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
|||||||
|
|
||||||
def read(
|
def read(
|
||||||
self,
|
self,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
device,
|
device,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -1339,7 +1343,7 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
data: Pointclouds,
|
data: Pointclouds,
|
||||||
path: Union[str, Path],
|
path: PathOrStr,
|
||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import pathlib
|
import pathlib
|
||||||
import warnings
|
import warnings
|
||||||
from typing import IO, ContextManager, Optional
|
from typing import IO, ContextManager, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -25,6 +25,9 @@ def nullcontext(x):
|
|||||||
yield x
|
yield x
|
||||||
|
|
||||||
|
|
||||||
|
PathOrStr = Union[pathlib.Path, str]
|
||||||
|
|
||||||
|
|
||||||
def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
|
def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
|
||||||
if isinstance(f, str):
|
if isinstance(f, str):
|
||||||
f = path_manager.open(f, mode)
|
f = path_manager.open(f, mode)
|
||||||
|
@ -20,10 +20,12 @@ from pytorch3d import _C
|
|||||||
class BlendParams(NamedTuple):
|
class BlendParams(NamedTuple):
|
||||||
sigma: float = 1e-4
|
sigma: float = 1e-4
|
||||||
gamma: float = 1e-4
|
gamma: float = 1e-4
|
||||||
background_color: Sequence = (1.0, 1.0, 1.0)
|
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
def hard_rgb_blend(
|
||||||
|
colors: torch.Tensor, fragments, blend_params: BlendParams
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Naive blending of top K faces to return an RGBA image
|
Naive blending of top K faces to return an RGBA image
|
||||||
- **RGB** - choose color of the closest point i.e. K=0
|
- **RGB** - choose color of the closest point i.e. K=0
|
||||||
@ -47,10 +49,11 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
|||||||
# Mask for the background.
|
# Mask for the background.
|
||||||
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
||||||
|
|
||||||
if torch.is_tensor(blend_params.background_color):
|
background_color_ = blend_params.background_color
|
||||||
background_color = blend_params.background_color.to(device)
|
if isinstance(background_color_, torch.Tensor):
|
||||||
|
background_color = background_color_.to(device)
|
||||||
else:
|
else:
|
||||||
background_color = colors.new_tensor(blend_params.background_color) # (3)
|
background_color = colors.new_tensor(background_color_) # pyre-fixme[16]
|
||||||
|
|
||||||
# Find out how much background_color needs to be expanded to be used for masked_scatter.
|
# Find out how much background_color needs to be expanded to be used for masked_scatter.
|
||||||
num_background_pixels = is_background.sum()
|
num_background_pixels = is_background.sum()
|
||||||
@ -90,7 +93,7 @@ class _SigmoidAlphaBlend(torch.autograd.Function):
|
|||||||
_sigmoid_alpha = _SigmoidAlphaBlend.apply
|
_sigmoid_alpha = _SigmoidAlphaBlend.apply
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
def sigmoid_alpha_blend(colors, fragments, blend_params: BlendParams) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Silhouette blending to return an RGBA image
|
Silhouette blending to return an RGBA image
|
||||||
- **RGB** - choose color of the closest point.
|
- **RGB** - choose color of the closest point.
|
||||||
@ -121,9 +124,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
def softmax_rgb_blend(
|
def softmax_rgb_blend(
|
||||||
colors,
|
colors: torch.Tensor,
|
||||||
fragments,
|
fragments,
|
||||||
blend_params,
|
blend_params: BlendParams,
|
||||||
znear: Union[float, torch.Tensor] = 1.0,
|
znear: Union[float, torch.Tensor] = 1.0,
|
||||||
zfar: Union[float, torch.Tensor] = 100,
|
zfar: Union[float, torch.Tensor] = 100,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -167,11 +170,11 @@ def softmax_rgb_blend(
|
|||||||
N, H, W, K = fragments.pix_to_face.shape
|
N, H, W, K = fragments.pix_to_face.shape
|
||||||
device = fragments.pix_to_face.device
|
device = fragments.pix_to_face.device
|
||||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
|
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
|
||||||
background = blend_params.background_color
|
background_ = blend_params.background_color
|
||||||
if not torch.is_tensor(background):
|
if not isinstance(background_, torch.Tensor):
|
||||||
background = torch.tensor(background, dtype=torch.float32, device=device)
|
background = torch.tensor(background_, dtype=torch.float32, device=device)
|
||||||
else:
|
else:
|
||||||
background = background.to(device)
|
background = background_.to(device)
|
||||||
|
|
||||||
# Weight for background color
|
# Weight for background color
|
||||||
eps = 1e-10
|
eps = 1e-10
|
||||||
|
@ -172,9 +172,11 @@ class CamerasBase(TensorProperties):
|
|||||||
A Transform3d object which represents a batch of transforms
|
A Transform3d object which represents a batch of transforms
|
||||||
of shape (N, 3, 3)
|
of shape (N, 3, 3)
|
||||||
"""
|
"""
|
||||||
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
|
R: torch.Tensor = kwargs.get("R", self.R)
|
||||||
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
|
T: torch.Tensor = kwargs.get("T", self.T)
|
||||||
world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
|
self.R = R # pyre-ignore[16]
|
||||||
|
self.T = T # pyre-ignore[16]
|
||||||
|
world_to_view_transform = get_world_to_view_transform(R=R, T=T)
|
||||||
return world_to_view_transform
|
return world_to_view_transform
|
||||||
|
|
||||||
def get_full_projection_transform(self, **kwargs) -> Transform3d:
|
def get_full_projection_transform(self, **kwargs) -> Transform3d:
|
||||||
@ -195,8 +197,8 @@ class CamerasBase(TensorProperties):
|
|||||||
a Transform3d object which represents a batch of transforms
|
a Transform3d object which represents a batch of transforms
|
||||||
of shape (N, 3, 3)
|
of shape (N, 3, 3)
|
||||||
"""
|
"""
|
||||||
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
|
self.R: torch.Tensor = kwargs.get("R", self.R) # pyre-ignore[16]
|
||||||
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
|
self.T: torch.Tensor = kwargs.get("T", self.T) # pyre-ignore[16]
|
||||||
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
|
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
|
||||||
view_to_ndc_transform = self.get_projection_transform(**kwargs)
|
view_to_ndc_transform = self.get_projection_transform(**kwargs)
|
||||||
return world_to_view_transform.compose(view_to_ndc_transform)
|
return world_to_view_transform.compose(view_to_ndc_transform)
|
||||||
@ -293,10 +295,10 @@ def OpenGLPerspectiveCameras(
|
|||||||
aspect_ratio=1.0,
|
aspect_ratio=1.0,
|
||||||
fov=60.0,
|
fov=60.0,
|
||||||
degrees: bool = True,
|
degrees: bool = True,
|
||||||
R=_R,
|
R: torch.Tensor = _R,
|
||||||
T=_T,
|
T: torch.Tensor = _T,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
):
|
) -> "FoVPerspectiveCameras":
|
||||||
"""
|
"""
|
||||||
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
|
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
|
||||||
Preserving OpenGLPerspectiveCameras for backward compatibility.
|
Preserving OpenGLPerspectiveCameras for backward compatibility.
|
||||||
@ -360,9 +362,9 @@ class FoVPerspectiveCameras(CamerasBase):
|
|||||||
aspect_ratio=1.0,
|
aspect_ratio=1.0,
|
||||||
fov=60.0,
|
fov=60.0,
|
||||||
degrees: bool = True,
|
degrees: bool = True,
|
||||||
R=_R,
|
R: torch.Tensor = _R,
|
||||||
T=_T,
|
T: torch.Tensor = _T,
|
||||||
K=None,
|
K: Optional[torch.Tensor] = None,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -397,7 +399,7 @@ class FoVPerspectiveCameras(CamerasBase):
|
|||||||
self.degrees = degrees
|
self.degrees = degrees
|
||||||
|
|
||||||
def compute_projection_matrix(
|
def compute_projection_matrix(
|
||||||
self, znear, zfar, fov, aspect_ratio, degrees
|
self, znear, zfar, fov, aspect_ratio, degrees: bool
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute the calibration matrix K of shape (N, 4, 4)
|
Compute the calibration matrix K of shape (N, 4, 4)
|
||||||
@ -559,10 +561,10 @@ def OpenGLOrthographicCameras(
|
|||||||
left=-1.0,
|
left=-1.0,
|
||||||
right=1.0,
|
right=1.0,
|
||||||
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
|
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
|
||||||
R=_R,
|
R: torch.Tensor = _R,
|
||||||
T=_T,
|
T: torch.Tensor = _T,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
) -> "FoVOrthographicCameras":
|
||||||
"""
|
"""
|
||||||
OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead.
|
OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead.
|
||||||
Preserving OpenGLOrthographicCameras for backward compatibility.
|
Preserving OpenGLOrthographicCameras for backward compatibility.
|
||||||
@ -605,10 +607,10 @@ class FoVOrthographicCameras(CamerasBase):
|
|||||||
max_x=1.0,
|
max_x=1.0,
|
||||||
min_x=-1.0,
|
min_x=-1.0,
|
||||||
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
|
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
|
||||||
R=_R,
|
R: torch.Tensor = _R,
|
||||||
T=_T,
|
T: torch.Tensor = _T,
|
||||||
K=None,
|
K: Optional[torch.Tensor] = None,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -784,8 +786,12 @@ we assume the parameters are in screen space.
|
|||||||
|
|
||||||
|
|
||||||
def SfMPerspectiveCameras(
|
def SfMPerspectiveCameras(
|
||||||
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
|
focal_length=1.0,
|
||||||
):
|
principal_point=((0.0, 0.0),),
|
||||||
|
R: torch.Tensor = _R,
|
||||||
|
T: torch.Tensor = _T,
|
||||||
|
device: Device = "cpu",
|
||||||
|
) -> "PerspectiveCameras":
|
||||||
"""
|
"""
|
||||||
SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
|
SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
|
||||||
Preserving SfMPerspectiveCameras for backward compatibility.
|
Preserving SfMPerspectiveCameras for backward compatibility.
|
||||||
@ -843,10 +849,10 @@ class PerspectiveCameras(CamerasBase):
|
|||||||
self,
|
self,
|
||||||
focal_length=1.0,
|
focal_length=1.0,
|
||||||
principal_point=((0.0, 0.0),),
|
principal_point=((0.0, 0.0),),
|
||||||
R=_R,
|
R: torch.Tensor = _R,
|
||||||
T=_T,
|
T: torch.Tensor = _T,
|
||||||
K=None,
|
K: Optional[torch.Tensor] = None,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
image_size=((-1, -1),),
|
image_size=((-1, -1),),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -950,8 +956,12 @@ class PerspectiveCameras(CamerasBase):
|
|||||||
|
|
||||||
|
|
||||||
def SfMOrthographicCameras(
|
def SfMOrthographicCameras(
|
||||||
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
|
focal_length=1.0,
|
||||||
):
|
principal_point=((0.0, 0.0),),
|
||||||
|
R: torch.Tensor = _R,
|
||||||
|
T: torch.Tensor = _T,
|
||||||
|
device: Device = "cpu",
|
||||||
|
) -> "OrthographicCameras":
|
||||||
"""
|
"""
|
||||||
SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
|
SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
|
||||||
Preserving SfMOrthographicCameras for backward compatibility.
|
Preserving SfMOrthographicCameras for backward compatibility.
|
||||||
@ -1008,10 +1018,10 @@ class OrthographicCameras(CamerasBase):
|
|||||||
self,
|
self,
|
||||||
focal_length=1.0,
|
focal_length=1.0,
|
||||||
principal_point=((0.0, 0.0),),
|
principal_point=((0.0, 0.0),),
|
||||||
R=_R,
|
R: torch.Tensor = _R,
|
||||||
T=_T,
|
T: torch.Tensor = _T,
|
||||||
K=None,
|
K: Optional[torch.Tensor] = None,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
image_size=((-1, -1),),
|
image_size=((-1, -1),),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -1116,8 +1126,8 @@ class OrthographicCameras(CamerasBase):
|
|||||||
|
|
||||||
|
|
||||||
def _get_sfm_calibration_matrix(
|
def _get_sfm_calibration_matrix(
|
||||||
N,
|
N: int,
|
||||||
device,
|
device: Device,
|
||||||
focal_length,
|
focal_length,
|
||||||
principal_point,
|
principal_point,
|
||||||
orthographic: bool = False,
|
orthographic: bool = False,
|
||||||
@ -1216,7 +1226,9 @@ def _get_sfm_calibration_matrix(
|
|||||||
################################################
|
################################################
|
||||||
|
|
||||||
|
|
||||||
def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
|
def get_world_to_view_transform(
|
||||||
|
R: torch.Tensor = _R, T: torch.Tensor = _T
|
||||||
|
) -> Transform3d:
|
||||||
"""
|
"""
|
||||||
This function returns a Transform3d representing the transformation
|
This function returns a Transform3d representing the transformation
|
||||||
matrix to go from world space to view space by applying a rotation and
|
matrix to go from world space to view space by applying a rotation and
|
||||||
@ -1250,13 +1262,17 @@ def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
|
|||||||
raise ValueError(msg % repr(R.shape))
|
raise ValueError(msg % repr(R.shape))
|
||||||
|
|
||||||
# Create a Transform3d object
|
# Create a Transform3d object
|
||||||
T = Translate(T, device=T.device)
|
T_ = Translate(T, device=T.device)
|
||||||
R = Rotate(R, device=R.device)
|
R_ = Rotate(R, device=R.device)
|
||||||
return R.compose(T)
|
return R_.compose(T_)
|
||||||
|
|
||||||
|
|
||||||
def camera_position_from_spherical_angles(
|
def camera_position_from_spherical_angles(
|
||||||
distance, elevation, azimuth, degrees: bool = True, device: str = "cpu"
|
distance: float,
|
||||||
|
elevation: float,
|
||||||
|
azimuth: float,
|
||||||
|
degrees: bool = True,
|
||||||
|
device: Device = "cpu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Calculate the location of the camera based on the distance away from
|
Calculate the location of the camera based on the distance away from
|
||||||
@ -1294,7 +1310,7 @@ def camera_position_from_spherical_angles(
|
|||||||
|
|
||||||
|
|
||||||
def look_at_rotation(
|
def look_at_rotation(
|
||||||
camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: str = "cpu"
|
camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: Device = "cpu"
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function takes a vector 'camera_position' which specifies the location
|
This function takes a vector 'camera_position' which specifies the location
|
||||||
@ -1351,7 +1367,7 @@ def look_at_view_transform(
|
|||||||
eye: Optional[Sequence] = None,
|
eye: Optional[Sequence] = None,
|
||||||
at=((0, 0, 0),), # (1, 3)
|
at=((0, 0, 0),), # (1, 3)
|
||||||
up=((0, 1, 0),), # (1, 3)
|
up=((0, 1, 0),), # (1, 3)
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
This function returns a rotation and translation matrix
|
This function returns a rotation and translation matrix
|
||||||
|
@ -5,11 +5,13 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ...common.types import Device
|
from ...common.types import Device
|
||||||
|
from ...structures.meshes import Meshes
|
||||||
from ..blending import (
|
from ..blending import (
|
||||||
BlendParams,
|
BlendParams,
|
||||||
hard_rgb_blend,
|
hard_rgb_blend,
|
||||||
@ -18,6 +20,8 @@ from ..blending import (
|
|||||||
)
|
)
|
||||||
from ..lighting import PointLights
|
from ..lighting import PointLights
|
||||||
from ..materials import Materials
|
from ..materials import Materials
|
||||||
|
from ..utils import TensorProperties
|
||||||
|
from .rasterizer import Fragments
|
||||||
from .shading import flat_shading, gouraud_shading, phong_shading
|
from .shading import flat_shading, gouraud_shading, phong_shading
|
||||||
|
|
||||||
|
|
||||||
@ -47,10 +51,10 @@ class HardPhongShader(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
cameras=None,
|
cameras: Optional[TensorProperties] = None,
|
||||||
lights=None,
|
lights: Optional[TensorProperties] = None,
|
||||||
materials=None,
|
materials: Optional[Materials] = None,
|
||||||
blend_params=None,
|
blend_params: Optional[BlendParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -62,13 +66,14 @@ class HardPhongShader(nn.Module):
|
|||||||
|
|
||||||
def to(self, device: Device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
if self.cameras is not None:
|
cameras = self.cameras
|
||||||
self.cameras = self.cameras.to(device)
|
if cameras is not None:
|
||||||
|
self.cameras = cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
msg = "Cameras must be specified either at initialization \
|
msg = "Cameras must be specified either at initialization \
|
||||||
@ -108,10 +113,10 @@ class SoftPhongShader(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
cameras=None,
|
cameras: Optional[TensorProperties] = None,
|
||||||
lights=None,
|
lights: Optional[TensorProperties] = None,
|
||||||
materials=None,
|
materials: Optional[Materials] = None,
|
||||||
blend_params=None,
|
blend_params: Optional[BlendParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -123,13 +128,14 @@ class SoftPhongShader(nn.Module):
|
|||||||
|
|
||||||
def to(self, device: Device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
if self.cameras is not None:
|
cameras = self.cameras
|
||||||
self.cameras = self.cameras.to(device)
|
if cameras is not None:
|
||||||
|
self.cameras = cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
msg = "Cameras must be specified either at initialization \
|
msg = "Cameras must be specified either at initialization \
|
||||||
@ -174,10 +180,10 @@ class HardGouraudShader(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
cameras=None,
|
cameras: Optional[TensorProperties] = None,
|
||||||
lights=None,
|
lights: Optional[TensorProperties] = None,
|
||||||
materials=None,
|
materials: Optional[Materials] = None,
|
||||||
blend_params=None,
|
blend_params: Optional[BlendParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -189,13 +195,14 @@ class HardGouraudShader(nn.Module):
|
|||||||
|
|
||||||
def to(self, device: Device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
if self.cameras is not None:
|
cameras = self.cameras
|
||||||
self.cameras = self.cameras.to(device)
|
if cameras is not None:
|
||||||
|
self.cameras = cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
msg = "Cameras must be specified either at initialization \
|
msg = "Cameras must be specified either at initialization \
|
||||||
@ -239,10 +246,10 @@ class SoftGouraudShader(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
cameras=None,
|
cameras: Optional[TensorProperties] = None,
|
||||||
lights=None,
|
lights: Optional[TensorProperties] = None,
|
||||||
materials=None,
|
materials: Optional[Materials] = None,
|
||||||
blend_params=None,
|
blend_params: Optional[BlendParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -254,13 +261,14 @@ class SoftGouraudShader(nn.Module):
|
|||||||
|
|
||||||
def to(self, device: Device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
if self.cameras is not None:
|
cameras = self.cameras
|
||||||
self.cameras = self.cameras.to(device)
|
if cameras is not None:
|
||||||
|
self.cameras = cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
msg = "Cameras must be specified either at initialization \
|
msg = "Cameras must be specified either at initialization \
|
||||||
@ -284,7 +292,11 @@ class SoftGouraudShader(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def TexturedSoftPhongShader(
|
def TexturedSoftPhongShader(
|
||||||
device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
|
device: Device = "cpu",
|
||||||
|
cameras: Optional[TensorProperties] = None,
|
||||||
|
lights: Optional[TensorProperties] = None,
|
||||||
|
materials: Optional[Materials] = None,
|
||||||
|
blend_params: Optional[BlendParams] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
|
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
|
||||||
@ -321,10 +333,10 @@ class HardFlatShader(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
cameras=None,
|
cameras: Optional[TensorProperties] = None,
|
||||||
lights=None,
|
lights: Optional[TensorProperties] = None,
|
||||||
materials=None,
|
materials: Optional[Materials] = None,
|
||||||
blend_params=None,
|
blend_params: Optional[BlendParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
@ -336,13 +348,14 @@ class HardFlatShader(nn.Module):
|
|||||||
|
|
||||||
def to(self, device: Device):
|
def to(self, device: Device):
|
||||||
# Manually move to device modules which are not subclasses of nn.Module
|
# Manually move to device modules which are not subclasses of nn.Module
|
||||||
if self.cameras is not None:
|
cameras = self.cameras
|
||||||
self.cameras = self.cameras.to(device)
|
if cameras is not None:
|
||||||
|
self.cameras = cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
if cameras is None:
|
if cameras is None:
|
||||||
msg = "Cameras must be specified either at initialization \
|
msg = "Cameras must be specified either at initialization \
|
||||||
@ -381,11 +394,11 @@ class SoftSilhouetteShader(nn.Module):
|
|||||||
3D Reasoning', ICCV 2019
|
3D Reasoning', ICCV 2019
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, blend_params=None) -> None:
|
def __init__(self, blend_params: Optional[BlendParams] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Only want to render the silhouette so RGB values can be ones.
|
Only want to render the silhouette so RGB values can be ones.
|
||||||
There is no need for lighting or texturing
|
There is no need for lighting or texturing
|
||||||
|
@ -223,7 +223,7 @@ class TexturesBase:
|
|||||||
|
|
||||||
return new_props
|
return new_props
|
||||||
|
|
||||||
def sample_textures(self):
|
def sample_textures(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Different texture classes sample textures in different ways
|
Different texture classes sample textures in different ways
|
||||||
e.g. for vertex textures, the values at each vertex
|
e.g. for vertex textures, the values at each vertex
|
||||||
@ -237,7 +237,7 @@ class TexturesBase:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def faces_verts_textures_packed(self):
|
def faces_verts_textures_packed(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Returns the texture for each vertex for each face in the mesh.
|
Returns the texture for each vertex for each face in the mesh.
|
||||||
For N meshes, this function returns sum(Fi)x3xC where Fi is the
|
For N meshes, this function returns sum(Fi)x3xC where Fi is the
|
||||||
@ -248,14 +248,14 @@ class TexturesBase:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def clone(self):
|
def clone(self) -> "TexturesBase":
|
||||||
"""
|
"""
|
||||||
Each texture class should implement a method
|
Each texture class should implement a method
|
||||||
to clone all necessary internal tensors.
|
to clone all necessary internal tensors.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def detach(self):
|
def detach(self) -> "TexturesBase":
|
||||||
"""
|
"""
|
||||||
Each texture class should implement a method
|
Each texture class should implement a method
|
||||||
to detach all necessary internal tensors.
|
to detach all necessary internal tensors.
|
||||||
@ -394,7 +394,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
# refer to the __init__ of Meshes.
|
# refer to the __init__ of Meshes.
|
||||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self) -> "TexturesAtlas":
|
||||||
tex = self.__class__(atlas=self.atlas_padded().clone())
|
tex = self.__class__(atlas=self.atlas_padded().clone())
|
||||||
if self._atlas_list is not None:
|
if self._atlas_list is not None:
|
||||||
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
|
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
|
||||||
@ -407,7 +407,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
tex._num_faces_per_mesh = num_faces
|
tex._num_faces_per_mesh = num_faces
|
||||||
return tex
|
return tex
|
||||||
|
|
||||||
def detach(self):
|
def detach(self) -> "TexturesAtlas":
|
||||||
tex = self.__class__(atlas=self.atlas_padded().detach())
|
tex = self.__class__(atlas=self.atlas_padded().detach())
|
||||||
if self._atlas_list is not None:
|
if self._atlas_list is not None:
|
||||||
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
|
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
|
||||||
|
@ -160,7 +160,7 @@ class TensorProperties(nn.Module):
|
|||||||
msg = "Expected index of type int or slice; got %r"
|
msg = "Expected index of type int or slice; got %r"
|
||||||
raise ValueError(msg % type(index))
|
raise ValueError(msg % type(index))
|
||||||
|
|
||||||
def to(self, device: Device = "cpu"):
|
def to(self, device: Device = "cpu") -> "TensorProperties":
|
||||||
"""
|
"""
|
||||||
In place operation to move class properties which are tensors to a
|
In place operation to move class properties which are tensors to a
|
||||||
specified device. If self has a property "device", update this as well.
|
specified device. If self has a property "device", update this as well.
|
||||||
@ -174,7 +174,7 @@ class TensorProperties(nn.Module):
|
|||||||
setattr(self, k, v.to(device_))
|
setattr(self, k, v.to(device_))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def clone(self, other):
|
def clone(self, other) -> "TensorProperties":
|
||||||
"""
|
"""
|
||||||
Update the tensor properties of other with the cloned properties of self.
|
Update the tensor properties of other with the cloned properties of self.
|
||||||
"""
|
"""
|
||||||
@ -189,7 +189,7 @@ class TensorProperties(nn.Module):
|
|||||||
setattr(other, k, v_clone)
|
setattr(other, k, v_clone)
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def gather_props(self, batch_idx):
|
def gather_props(self, batch_idx) -> "TensorProperties":
|
||||||
"""
|
"""
|
||||||
This is an in place operation to reformat all tensor class attributes
|
This is an in place operation to reformat all tensor class attributes
|
||||||
based on a set of given indices using torch.gather. This is useful when
|
based on a set of given indices using torch.gather. This is useful when
|
||||||
|
@ -187,15 +187,15 @@ class Volumes:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# handle densities
|
# handle densities
|
||||||
densities, grid_sizes = self._convert_densities_features_to_tensor(
|
densities_, grid_sizes = self._convert_densities_features_to_tensor(
|
||||||
densities, "densities"
|
densities, "densities"
|
||||||
)
|
)
|
||||||
|
|
||||||
# take device from densities
|
# take device from densities
|
||||||
self.device = densities.device
|
self.device = densities_.device
|
||||||
|
|
||||||
# assign to the internal buffers
|
# assign to the internal buffers
|
||||||
self._densities = densities
|
self._densities = densities_
|
||||||
self._grid_sizes = grid_sizes
|
self._grid_sizes = grid_sizes
|
||||||
|
|
||||||
# handle features
|
# handle features
|
||||||
@ -497,7 +497,6 @@ class Volumes:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
# pyre-fixme[16]: `List` has no attribute `shape`.
|
|
||||||
return self._densities.shape[0]
|
return self._densities.shape[0]
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
@ -547,8 +546,6 @@ class Volumes:
|
|||||||
Returns:
|
Returns:
|
||||||
**densities**: The tensor of volume densities.
|
**densities**: The tensor of volume densities.
|
||||||
"""
|
"""
|
||||||
# pyre-fixme[7]: Expected `Tensor` but got `Union[List[torch.Tensor],
|
|
||||||
# torch.Tensor]`.
|
|
||||||
return self._densities
|
return self._densities
|
||||||
|
|
||||||
def densities_list(self) -> List[torch.Tensor]:
|
def densities_list(self) -> List[torch.Tensor]:
|
||||||
@ -723,7 +720,6 @@ class Volumes:
|
|||||||
return other
|
return other
|
||||||
|
|
||||||
other.device = device_
|
other.device = device_
|
||||||
# pyre-fixme[16]: `List` has no attribute `to`.
|
|
||||||
other._densities = self._densities.to(device_)
|
other._densities = self._densities.to(device_)
|
||||||
if self._features is not None:
|
if self._features is not None:
|
||||||
# pyre-fixme[16]: `Optional` has no attribute `to`.
|
# pyre-fixme[16]: `Optional` has no attribute `to`.
|
||||||
|
@ -10,6 +10,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ..common.types import Device
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
The transformation matrices returned from the functions in this file assume
|
The transformation matrices returned from the functions in this file assume
|
||||||
@ -286,7 +288,9 @@ def matrix_to_euler_angles(matrix, convention: str):
|
|||||||
return torch.stack(o, -1)
|
return torch.stack(o, -1)
|
||||||
|
|
||||||
|
|
||||||
def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
def random_quaternions(
|
||||||
|
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generate random quaternions representing rotations,
|
Generate random quaternions representing rotations,
|
||||||
i.e. versors with nonnegative real part.
|
i.e. versors with nonnegative real part.
|
||||||
@ -306,7 +310,9 @@ def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None)
|
|||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
def random_rotations(
|
||||||
|
n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generate random rotations as 3x3 rotation matrices.
|
Generate random rotations as 3x3 rotation matrices.
|
||||||
|
|
||||||
@ -323,7 +329,9 @@ def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
|||||||
return quaternion_to_matrix(quaternions)
|
return quaternion_to_matrix(quaternions)
|
||||||
|
|
||||||
|
|
||||||
def random_rotation(dtype: Optional[torch.dtype] = None, device=None):
|
def random_rotation(
|
||||||
|
dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generate a single random 3x3 rotation matrix.
|
Generate a single random 3x3 rotation matrix.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user