Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 20:02:49 +08:00
Moved MLP and Transformer

Summary: Moved the MLP and transformer from nerf to a new file to be reused.

Reviewed By: bottler

Differential Revision: D38828150

fbshipit-source-id: 8ff77b18b3aeeda398d90758a7bcb2482edce66f

This commit is contained in:
parent edee25a1e5
commit 898ba5c53b
@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)


class MLPWithInputSkips(torch.nn.Module):
    """
    Implements the multi-layer perceptron architecture of the Neural Radiance Field.

    As such, `MLPWithInputSkips` is a multi layer perceptron consisting
    of a sequence of linear layers with ReLU activations.

    Additionally, for a set of predefined layers `input_skips`, the forward pass
    appends a skip tensor `z` to the output of the preceding layer.

    Note that this follows the architecture described in the Supplementary
    Material (Fig. 7) of [1].

    References:
        [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
            and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
            NeRF: Representing Scenes as Neural Radiance Fields for View
            Synthesis, ECCV2020
    """

    def _make_affine_layer(self, input_dim, hidden_dim):
        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
        _xavier_init(l1)
        _xavier_init(l2)
        return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)

    def _apply_affine_layer(self, layer, x, z):
        mu_log_std = layer(z)
        mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
        std = torch.nn.functional.softplus(log_std)
        return (x - mu) * std

    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 256,
        input_skips: Tuple[int, ...] = (5,),
        skip_affine_trans: bool = False,
        no_last_relu=False,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()
        layers = []
        skip_affine_layers = []
        for layeri in range(n_layers):
            dimin = hidden_dim if layeri > 0 else input_dim
            dimout = hidden_dim if layeri + 1 < n_layers else output_dim

            if layeri > 0 and layeri in input_skips:
                if skip_affine_trans:
                    skip_affine_layers.append(
                        self._make_affine_layer(skip_dim, hidden_dim)
                    )
                else:
                    dimin = hidden_dim + skip_dim

            linear = torch.nn.Linear(dimin, dimout)
            _xavier_init(linear)
            layers.append(
                torch.nn.Sequential(linear, torch.nn.ReLU(True))
                if not no_last_relu or layeri + 1 < n_layers
                else linear
            )
        self.mlp = torch.nn.ModuleList(layers)
        if skip_affine_trans:
            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
        self._input_skips = set(input_skips)
        self._skip_affine_trans = skip_affine_trans

    def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
        """
        Args:
            x: The input tensor of shape `(..., input_dim)`.
            z: The input skip tensor of shape `(..., skip_dim)` which is appended
                to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape `(..., output_dim)`.
        """
        y = x
        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x
        skipi = 0
        for li, layer in enumerate(self.mlp):
            if li in self._input_skips:
                if self._skip_affine_trans:
                    y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
                else:
                    y = torch.cat((y, z), dim=-1)
                skipi += 1
            y = layer(y)
        return y

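A brief usage sketch (not part of the commit) may help illustrate the skip behaviour described in the docstring above: with the default `input_skips=(5,)`, the 39-channel input embedding is concatenated back onto the hidden activations before layer 5. The tensor sizes below are illustrative assumptions.

# Usage sketch only; not part of the diff. Default NeRF-style configuration.
import torch

mlp = MLPWithInputSkips(
    n_layers=8,
    input_dim=39,
    output_dim=256,
    skip_dim=39,
    hidden_dim=256,
    input_skips=(5,),
)
embedding = torch.randn(2, 1024, 39)  # (batch, points, harmonic embedding), assumed sizes
features = mlp(embedding)             # the skip tensor `z` falls back to `x` when omitted
print(features.shape)                 # torch.Size([2, 1024, 256])
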
class TransformerWithInputSkips(torch.nn.Module):
    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 64,
        input_skips: Tuple[int, ...] = (5,),
        dim_down_factor: float = 1,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()

        self.first = torch.nn.Linear(input_dim, hidden_dim)
        _xavier_init(self.first)

        self.skip_linear = torch.nn.ModuleList()

        layers_pool, layers_ray = [], []
        dimout = 0
        for layeri in range(n_layers):
            dimin = int(round(hidden_dim / (dim_down_factor**layeri)))
            dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
            logger.info(f"Tr: {dimin} -> {dimout}")
            for _i, l in enumerate((layers_pool, layers_ray)):
                l.append(
                    TransformerEncoderLayer(
                        d_model=[dimin, dimout][_i],
                        nhead=4,
                        dim_feedforward=hidden_dim,
                        dropout=0.0,
                        d_model_out=dimout,
                    )
                )

            if layeri in input_skips:
                self.skip_linear.append(torch.nn.Linear(input_dim, dimin))

        self.last = torch.nn.Linear(dimout, output_dim)
        _xavier_init(self.last)

        # pyre-fixme[8]: Attribute has type `Tuple[ModuleList, ModuleList]`; used as
        # `ModuleList`.
        self.layers_pool, self.layers_ray = (
            torch.nn.ModuleList(layers_pool),
            torch.nn.ModuleList(layers_ray),
        )
        self._input_skips = set(input_skips)

    def forward(
        self,
        x: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: The input tensor of shape
                `(minibatch, n_pooled_feats, ..., n_ray_pts, input_dim)`.
            z: The input skip tensor of shape
                `(minibatch, n_pooled_feats, ..., n_ray_pts, skip_dim)`
                which is appended to layers whose indices are specified by `input_skips`.
        Returns:
            y: The output tensor of shape
                `(minibatch, 1, ..., n_ray_pts, input_dim)`.
        """

        if z is None:
            # if the skip tensor is None, we use `x` instead.
            z = x

        y = self.first(x)

        B, n_pool, n_rays, n_pts, dim = y.shape

        # y_p in n_pool, n_pts, B x n_rays x dim
        y_p = y.permute(1, 3, 0, 2, 4)

        skipi = 0
        dimh = dim
        for li, (layer_pool, layer_ray) in enumerate(
            zip(self.layers_pool, self.layers_ray)
        ):
            y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
            if li in self._input_skips:
                z_skip = self.skip_linear[skipi](z)
                y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
                    n_pool, n_pts * B * n_rays, dimh
                )
                skipi += 1
            # n_pool x B*n_rays*n_pts x dim
            y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None)
            dimh = y_pool_attn.shape[-1]

            y_ray_attn = (
                y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
                .permute(1, 0, 2, 3)
                .reshape(n_pts, n_pool * B * n_rays, dimh)
            )
            # n_pts x n_pool*B*n_rays x dim
            y_ray_attn, ray_attn = layer_ray(
                y_ray_attn,
                src_key_padding_mask=None,
            )

            y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)

        y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)

        W = torch.softmax(y[..., :1], dim=1)
        y = (y * W).sum(dim=1)
        y = self.last(y)

        return y

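As a hedged sketch (again not part of the commit), the transformer variant consumes a 5-dimensional tensor of shape `(minibatch, n_pooled_feats, n_rays, n_ray_pts, input_dim)` and attends alternately over the pooled-feature and ray-point dimensions; the pooled dimension is then collapsed by a learnt softmax weighting before the final linear layer. All sizes below are assumptions.

# Usage sketch only; not part of the diff. Small, assumed sizes.
import torch

net = TransformerWithInputSkips(
    n_layers=2,
    input_dim=39,
    output_dim=256,
    skip_dim=39,
    hidden_dim=64,
    input_skips=(1,),
)
x = torch.randn(2, 8, 4, 16, 39)  # (batch, pooled feats, rays, points per ray, channels)
y = net(x)                        # the skip tensor `z` defaults to `x`
print(y.shape)                    # torch.Size([2, 4, 16, 256]); the pooled dimension
                                  # is reduced by the softmax weighting before `self.last`
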
class TransformerEncoderLayer(torch.nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1
    ):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
        self.dropout = torch.nn.Dropout(dropout)
        d_model_out = d_model if d_model_out <= 0 else d_model_out
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model_out)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.activation = torch.nn.functional.relu

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2, attn = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        d_out = src2.shape[-1]
        src = src[..., :d_out] + self.dropout2(src2)[..., :d_out]
        src = self.norm2(src)
        return src, attn


def _xavier_init(linear) -> None:
    """
    Performs the Xavier weight initialization of the linear layer `linear`.
    """
    torch.nn.init.xavier_uniform_(linear.weight.data)

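The encoder layer above differs from `torch.nn.TransformerEncoderLayer` in two ways: it returns the attention weights alongside the output, and `d_model_out` lets the feed-forward branch change the channel width, with the residual truncated to match. A minimal sketch (not part of the commit, with assumed sizes):

# Usage sketch only; not part of the diff.
import torch

layer = TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=128, dropout=0.0, d_model_out=32
)
src = torch.rand(10, 7, 64)  # (sequence, batch, d_model), assumed sizes
out, attn = layer(src)       # attention weights are returned as well
print(out.shape)             # torch.Size([10, 7, 32]): narrowed by d_model_out
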
@@ -15,6 +15,12 @@ from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding

from .base import ImplicitFunctionBase
from .decoding_functions import (  # noqa
    _xavier_init,
    MLPWithInputSkips,
    TransformerWithInputSkips,
)
from .utils import create_embeddings_for_implicit_function

@@ -243,305 +249,3 @@ class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
        pooling without aggregation. Overridden from ImplicitFunctionBase.
        """
        return True
(The removed lines are the MLPWithInputSkips, TransformerWithInputSkips, TransformerEncoderLayer and _xavier_init definitions, identical to the code added above, now deleted from this file since they live in the new decoding_functions module.)