mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
voxel grid implicit function
Summary: The implicit function and its members and internal working Reviewed By: kjchalup Differential Revision: D38829764 fbshipit-source-id: 28394fe7819e311ed52c9defc9a1b29f37fbc495
This commit is contained in:
parent
d6a197be36
commit
c2d876c9e8
@ -18,6 +18,8 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
registry,
|
||||
@ -179,8 +181,11 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
||||
class MLPDecoder(DecoderFunctionBase):
|
||||
"""
|
||||
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
||||
If using Implicitron config system `input_dim` of the `network` is changed to the
|
||||
value of `input_dim` member and `input_skips` is removed.
|
||||
"""
|
||||
|
||||
input_dim: int = 3
|
||||
network: MLPWithInputSkips
|
||||
|
||||
def __post_init__(self):
|
||||
@ -192,6 +197,20 @@ class MLPDecoder(DecoderFunctionBase):
|
||||
) -> torch.Tensor:
|
||||
return self.network(features, z)
|
||||
|
||||
@classmethod
|
||||
def network_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
"""
|
||||
Special method to stop get_default_args exposing member's `input_dim`.
|
||||
"""
|
||||
args.pop("input_dim", None)
|
||||
|
||||
def create_network_impl(self, type, args: DictConfig) -> None:
|
||||
"""
|
||||
Set the input dimension of the `network` to the input dimension of the
|
||||
decoding function.
|
||||
"""
|
||||
self.network = MLPWithInputSkips(input_dim=self.input_dim, **args)
|
||||
|
||||
|
||||
class TransformerWithInputSkips(torch.nn.Module):
|
||||
def __init__(
|
||||
|
@ -65,7 +65,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
padding: str = "zeros"
|
||||
mode: str = "bilinear"
|
||||
n_features: int = 1
|
||||
resolution: Tuple[int, int, int] = (64, 64, 64)
|
||||
resolution: Tuple[int, int, int] = (128, 128, 128)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
@ -507,8 +507,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
||||
voxel_grid: VoxelGridBase
|
||||
|
||||
# pyre-fixme[8]: Attribute has type `Tuple[float, float, float]`; used as `float`.
|
||||
extents: Tuple[float, float, float] = 1.0
|
||||
extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
|
||||
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||
|
||||
init_std: float = 0.1
|
||||
@ -552,9 +551,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
grid_sizes=(2, 2, 2),
|
||||
# The locator object uses (x, y, z) convention for the
|
||||
# voxel size and translation.
|
||||
voxel_size=self.extents,
|
||||
volume_translation=self.translation,
|
||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
|
||||
voxel_size=tuple(self.extents),
|
||||
volume_translation=tuple(self.translation),
|
||||
# pyre-ignore[29]
|
||||
device=next(self.params.values()).device,
|
||||
)
|
||||
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
||||
@ -562,3 +561,18 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||
|
||||
@staticmethod
|
||||
def get_output_dim(args: DictConfig) -> int:
|
||||
"""
|
||||
Utility to help predict the shape of the output of `forward`.
|
||||
|
||||
Args:
|
||||
args: DictConfig which would be used to initialize the object
|
||||
Returns:
|
||||
int: the length of the last dimension of the output tensor
|
||||
"""
|
||||
grid = registry.get(VoxelGridBase, args["voxel_grid_class_type"])
|
||||
return grid.get_output_dim(
|
||||
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
|
||||
)
|
||||
|
@ -88,7 +88,7 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
embedding: a harmonic embedding of `x`
|
||||
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
||||
"""
|
||||
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
|
||||
embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
|
||||
embed = torch.cat(
|
||||
(embed.sin(), embed.cos(), x)
|
||||
if self.append_input
|
||||
|
Loading…
x
Reference in New Issue
Block a user