mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
voxel grid implicit function
Summary: The implicit function and its members and internal working Reviewed By: kjchalup Differential Revision: D38829764 fbshipit-source-id: 28394fe7819e311ed52c9defc9a1b29f37fbc495
This commit is contained in:
parent
d6a197be36
commit
c2d876c9e8
@ -18,6 +18,8 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
Configurable,
|
Configurable,
|
||||||
registry,
|
registry,
|
||||||
@ -179,8 +181,11 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
|||||||
class MLPDecoder(DecoderFunctionBase):
|
class MLPDecoder(DecoderFunctionBase):
|
||||||
"""
|
"""
|
||||||
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
||||||
|
If using Implicitron config system `input_dim` of the `network` is changed to the
|
||||||
|
value of `input_dim` member and `input_skips` is removed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
input_dim: int = 3
|
||||||
network: MLPWithInputSkips
|
network: MLPWithInputSkips
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -192,6 +197,20 @@ class MLPDecoder(DecoderFunctionBase):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.network(features, z)
|
return self.network(features, z)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def network_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
Special method to stop get_default_args exposing member's `input_dim`.
|
||||||
|
"""
|
||||||
|
args.pop("input_dim", None)
|
||||||
|
|
||||||
|
def create_network_impl(self, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
Set the input dimension of the `network` to the input dimension of the
|
||||||
|
decoding function.
|
||||||
|
"""
|
||||||
|
self.network = MLPWithInputSkips(input_dim=self.input_dim, **args)
|
||||||
|
|
||||||
|
|
||||||
class TransformerWithInputSkips(torch.nn.Module):
|
class TransformerWithInputSkips(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -65,7 +65,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
padding: str = "zeros"
|
padding: str = "zeros"
|
||||||
mode: str = "bilinear"
|
mode: str = "bilinear"
|
||||||
n_features: int = 1
|
n_features: int = 1
|
||||||
resolution: Tuple[int, int, int] = (64, 64, 64)
|
resolution: Tuple[int, int, int] = (128, 128, 128)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -507,8 +507,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
||||||
voxel_grid: VoxelGridBase
|
voxel_grid: VoxelGridBase
|
||||||
|
|
||||||
# pyre-fixme[8]: Attribute has type `Tuple[float, float, float]`; used as `float`.
|
extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
|
||||||
extents: Tuple[float, float, float] = 1.0
|
|
||||||
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
|
|
||||||
init_std: float = 0.1
|
init_std: float = 0.1
|
||||||
@ -552,9 +551,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
grid_sizes=(2, 2, 2),
|
grid_sizes=(2, 2, 2),
|
||||||
# The locator object uses (x, y, z) convention for the
|
# The locator object uses (x, y, z) convention for the
|
||||||
# voxel size and translation.
|
# voxel size and translation.
|
||||||
voxel_size=self.extents,
|
voxel_size=tuple(self.extents),
|
||||||
volume_translation=self.translation,
|
volume_translation=tuple(self.translation),
|
||||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
|
# pyre-ignore[29]
|
||||||
device=next(self.params.values()).device,
|
device=next(self.params.values()).device,
|
||||||
)
|
)
|
||||||
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
||||||
@ -562,3 +561,18 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
grid_values = self.voxel_grid.values_type(**self.params)
|
grid_values = self.voxel_grid.values_type(**self.params)
|
||||||
# voxel grids operate with extra n_grids dimension, which we fix to one
|
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||||
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_output_dim(args: DictConfig) -> int:
|
||||||
|
"""
|
||||||
|
Utility to help predict the shape of the output of `forward`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: DictConfig which would be used to initialize the object
|
||||||
|
Returns:
|
||||||
|
int: the length of the last dimension of the output tensor
|
||||||
|
"""
|
||||||
|
grid = registry.get(VoxelGridBase, args["voxel_grid_class_type"])
|
||||||
|
return grid.get_output_dim(
|
||||||
|
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
|
||||||
|
)
|
||||||
|
@ -88,7 +88,7 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
embedding: a harmonic embedding of `x`
|
embedding: a harmonic embedding of `x`
|
||||||
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
||||||
"""
|
"""
|
||||||
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
|
embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
|
||||||
embed = torch.cat(
|
embed = torch.cat(
|
||||||
(embed.sin(), embed.cos(), x)
|
(embed.sin(), embed.cos(), x)
|
||||||
if self.append_input
|
if self.append_input
|
||||||
|
Loading…
x
Reference in New Issue
Block a user