mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 03:40:34 +08:00
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476295 fbshipit-source-id: 5101d4aae980a9f8955a4cb10bae23997c48837f
140 lines
4.6 KiB
Python
140 lines
4.6 KiB
Python
# @lint-ignore-every LICENSELINT
|
|
# Adapted from RenderingNetwork from IDR
|
|
# https://github.com/lioryariv/idr/
|
|
# Copyright (c) 2020 Lior Yariv
|
|
|
|
# pyre-unsafe
|
|
|
|
import logging
|
|
from typing import List, Tuple
|
|
|
|
import torch
|
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
|
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
|
from torch import nn
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RayNormalColoringNetwork(torch.nn.Module):
    """
    Fully connected network which maps per-point inputs (points, view
    directions, normals and feature vectors, concatenated along the last
    dimension) to a `d_out`-dimensional output squashed by tanh.

    Members:
        d_in and feature_vector_size: Sum of these is the input
            dimension. These must add up to the sum of
            - 3 [for the points]
            - 3 unless mode=no_normal [for the normals]
            - 3 unless mode=no_view_dir [for view directions]
            - the feature size, [number of channels in feature_vectors]

        d_out: dimension of output.
        mode: One of "idr", "no_view_dir" or "no_normal" to allow omitting
            part of the network input.
        dims: list of hidden layer sizes.
        weight_norm: whether to apply weight normalization to each layer.
        n_harmonic_functions_dir:
            If >0, use a harmonic embedding with this number of
            harmonic functions for the view direction. Otherwise view directions
            are fed without embedding, unless mode is `no_view_dir`.
        pooled_feature_dim: If a pooling function is in use (provided as
            pooling_fn to forward()) this must be its number of features.
            Otherwise this must be set to 0. (If used from GenericModel,
            this will be set automatically.)
    """

    def __init__(
        self,
        feature_vector_size: int = 3,
        mode: str = "idr",
        d_in: int = 9,
        d_out: int = 3,
        dims: Tuple[int, ...] = (512, 512, 512, 512),
        weight_norm: bool = True,
        n_harmonic_functions_dir: int = 0,
        pooled_feature_dim: int = 0,
    ) -> None:
        super().__init__()

        self.mode = mode
        self.output_dimensions = d_out

        # Width of the first linear layer; grown below by optional inputs.
        input_dim = d_in + feature_vector_size

        self.embedview_fn = None
        if n_harmonic_functions_dir > 0:
            self.embedview_fn = HarmonicEmbedding(
                n_harmonic_functions_dir, append_input=True
            )
            # The embedded directions replace the 3 raw view-direction
            # channels which are already counted in d_in.
            input_dim += self.embedview_fn.get_output_dim() - 3

        if pooled_feature_dim > 0:
            logger.info("Pooled features in rendering network.")
            input_dim += pooled_feature_dim

        dims_full: List[int] = [input_dim] + list(dims) + [d_out]
        # Number of layer boundaries, i.e. one more than the number of
        # linear layers.
        self.num_layers = len(dims_full)

        layers = []
        for in_dim, out_dim in zip(dims_full[:-1], dims_full[1:]):
            lin = nn.Linear(in_dim, out_dim)
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            layers.append(lin)
        self.linear_layers = torch.nn.ModuleList(layers)

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(
        self,
        feature_vectors: torch.Tensor,
        points,
        normals,
        ray_bundle: ImplicitronRayBundle,
        masks=None,
        pooling_fn=None,
    ):
        """
        Evaluate the network.

        Args:
            feature_vectors: per-point features, concatenated last along the
                final dimension of the network input.
            points: per-point locations; always part of the network input
                and also the argument given to `pooling_fn`.
            normals: per-point normals; part of the network input unless
                mode is "no_normal". Also provides the leading shape, dtype
                and device of the all-zero result returned for an empty mask.
            ray_bundle: object whose `directions` are used as view
                directions. Not accessed when mode is "no_view_dir".
            masks: optional mask. If given and nothing is selected, zeros are
                returned without running the network. Otherwise it is
                flattened and used to select the view directions; the other
                inputs are expected to have been masked by the caller already.
            pooling_fn: optional callable, invoked as
                `pooling_fn(points[None])[0]`; its result is appended to the
                network input.

        Returns:
            Tensor whose last dimension is `d_out`, with tanh applied.

        Raises:
            ValueError: if `mode` is not one of the supported modes.
        """
        if masks is not None and not masks.any():
            # Nothing to evaluate. The result must have d_out channels like a
            # regular output, which need not equal the 3 channels of normals.
            empty_shape = (*normals.shape[:-1], self.output_dimensions)
            return normals.new_zeros(empty_shape)

        if self.mode not in ("idr", "no_view_dir", "no_normal"):
            raise ValueError(f"Unsupported rendering mode: {self.mode}")

        # Assemble the inputs in the order points, view dirs, normals,
        # features, pooled features.
        inputs = [points]

        if self.mode != "no_view_dir":
            view_dirs = ray_bundle.directions
            if masks is not None:
                # in case of IDR, other outputs are passed here after
                # applying the mask
                view_dirs = view_dirs.reshape(view_dirs.shape[0], -1, 3)[
                    :, masks.reshape(-1)
                ]
            if self.embedview_fn is not None:
                view_dirs = self.embedview_fn(view_dirs)
            inputs.append(view_dirs)

        if self.mode != "no_normal":
            inputs.append(normals)

        inputs.append(feature_vectors)

        if pooling_fn is not None:
            inputs.append(pooling_fn(points[None])[0])

        x = torch.cat(inputs, dim=-1)

        # ReLU after every layer except the last, which is followed by tanh.
        last_idx = len(self.linear_layers) - 1
        for layer_idx, layer in enumerate(self.linear_layers):
            x = layer(x)
            if layer_idx < last_idx:
                x = self.relu(x)

        return self.tanh(x)
|
|
|
|
|
|
# NOTE(review): presumably makes the constructor defaults of
# RayNormalColoringNetwork available to Implicitron's config tooling —
# confirm against pytorch3d.implicitron.tools.config.
enable_get_default_args(RayNormalColoringNetwork)
|